GTC Japan 2016 Chainer feature introduction
Chainer feature introduction: Training and dataset abstraction
5th Oct. 2016, GTC Japan @ Tokyo
Preferred Networks, Inc., Kenta Oono
Trainer and Dataset abstraction
• New feature from v1.11.0
  – Frees users from implementing training loops by themselves.
  – Supports most typical training procedures.
  – Easy to customize and extend.
• Note: We can still write training loops manually without this feature, as the examples in previous versions did.
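For comparison, here is a rough sketch of the kind of loop users wrote by hand before this feature: shuffle the data every epoch, slice it into mini-batches, compute gradients, and update the parameters. It is plain Python on a hypothetical toy problem (fitting y = 2x + 1 with mini-batch SGD), not Chainer code; `train_manually` and its parameters are illustrative names.

```python
import random


def train_manually(data, n_epochs=200, batch_size=4, lr=0.05, seed=0):
    """Fit y = w * x + b by mini-batch SGD using a hand-written loop."""
    rng = random.Random(seed)
    w, b = 0.0, 0.0
    data = list(data)
    for _ in range(n_epochs):
        rng.shuffle(data)  # new example order every epoch
        for i in range(0, len(data), batch_size):
            batch = data[i:i + batch_size]
            # Gradients of the mean squared error over this mini-batch
            grad_w = sum(2 * (w * x + b - y) * x for x, y in batch) / len(batch)
            grad_b = sum(2 * (w * x + b - y) for x, y in batch) / len(batch)
            w -= lr * grad_w
            b -= lr * grad_b
    return w, b


data = [(k / 10, 2 * (k / 10) + 1) for k in range(-10, 11)]
w, b = train_manually(data)
```

Every concern here (iteration order, batching, the update rule, when to stop) is tangled in one function; the Trainer abstraction separates them into Iterator, Updater, and stop trigger.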
Main modules
• Dataset, Iterator: extract mini-batches by iterating over datasets
• Trainer, Updater, Extension: customize the training loop at low cost
• Reporter: collect statistics from inside the models
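To make the Dataset/Iterator split concrete, here is a toy iterator in plain Python: the dataset is anything indexable with a length, and the iterator yields mini-batches, counts completed epochs, and can either repeat forever or stop after one pass. It loosely mirrors the `repeat` and `shuffle` options of SerialIterator used later in this deck, but `ToySerialIterator` and its attributes are hypothetical and are not Chainer's implementation (for example, this toy simply returns a shorter final batch).

```python
import random


class ToySerialIterator:
    """Yield mini-batches (lists of examples) from an indexable dataset."""

    def __init__(self, dataset, batch_size, repeat=True, shuffle=True, seed=0):
        self.dataset = dataset
        self.batch_size = batch_size
        self.repeat = repeat
        self.shuffle = shuffle
        self._rng = random.Random(seed)
        self.epoch = 0              # number of completed passes over the data
        self._order = self._new_order()
        self._pos = 0               # position inside the current pass

    def _new_order(self):
        order = list(range(len(self.dataset)))
        if self.shuffle:
            self._rng.shuffle(order)
        return order

    def __iter__(self):
        return self

    def __next__(self):
        if not self.repeat and self.epoch > 0:
            raise StopIteration     # single pass, e.g. for evaluation
        end = self._pos + self.batch_size
        batch = [self.dataset[i] for i in self._order[self._pos:end]]
        self._pos = end
        if self._pos >= len(self.dataset):
            # End of a pass: count the epoch and prepare a fresh order.
            self.epoch += 1
            self._pos = 0
            self._order = self._new_order()
        return batch
```

For instance, `list(ToySerialIterator(list(range(10)), 4, repeat=False, shuffle=False))` gives `[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]`; with `repeat=True` the same object would keep producing batches and the caller (in Chainer, the trainer's stop trigger) decides when to stop.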
[Diagram: a Trainer holds several Extensions and an Updater; the Updater holds Optimizers, each updating a target Link, and Iterators, each reading a Dataset.]
In practice we often use only one optimizer and one dataset; this diagram shows the general case.
MNIST classification by MLP with Trainer
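import chainer
import chainer.functions as F
import chainer.links as L

class MLP(chainer.Chain):  # Chain (not Link) can hold child links
    def __init__(self):
        super(MLP, self).__init__(
            l1=L.Linear(784, 1000),
            l2=L.Linear(1000, 1000),
            l3=L.Linear(1000, 10))

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)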
[Diagram: computational graph of MLP: x → Linear l1 (W, bias) → ReLU → h1 → Linear l2 (W, bias) → ReLU → h2 → Linear l3 (W, bias)]
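from chainer.datasets import get_mnist
from chainer.iterators import SerialIterator
from chainer.optimizers import Adam
from chainer.training import StandardUpdater, Trainer
from chainer.training.extensions import (
    Evaluator, LogReport, PrintReport, ProgressBar, dump_graph, snapshot)
import chainer.links as L

# Prepare datasets and their iterators
train, test = get_mnist()
train_iter = SerialIterator(train, 128)
test_iter = SerialIterator(test, 128, repeat=False, shuffle=False)

# Prepare links and their optimizers
model = L.Classifier(MLP())
optimizer = Adam()
optimizer.setup(model)

# Prepare trainer: stop after 10 epochs
updater = StandardUpdater(train_iter, optimizer)
trainer = Trainer(updater, (10, 'epoch'))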
# Add extensions to augment trainer
trainer.extend(Evaluator(test_iter, model))
trainer.extend(dump_graph('main/loss'))
trainer.extend(snapshot())
trainer.extend(LogReport())
trainer.extend(PrintReport(['epoch', 'main/accuracy',
                            'validation/main/accuracy']))
trainer.extend(ProgressBar())

# Execute
trainer.run()
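The keys given to PrintReport, such as 'main/accuracy' and 'validation/main/accuracy', name reported values prefixed by where they came from: the model attached to the main optimizer reports under 'main', and the Evaluator adds its own 'validation' prefix. The toy reporter below illustrates that idea, plus averaging the collected values over an interval (roughly what LogReport does per epoch). `ToyReporter`, `report`, and `summarize` are hypothetical names in plain Python, not Chainer's Reporter API, and the reported numbers are made up.

```python
from collections import defaultdict


class ToyReporter:
    """Collect scalar observations under 'prefix/name' keys."""

    def __init__(self):
        self._values = defaultdict(list)

    def report(self, values, prefix):
        # e.g. report({'loss': 0.3}, prefix='main') stores under 'main/loss'
        for name, value in values.items():
            self._values[prefix + "/" + name].append(value)

    def summarize(self):
        """Return the mean of each key and start a new interval."""
        means = {key: sum(vs) / len(vs) for key, vs in self._values.items()}
        self._values.clear()
        return means


reporter = ToyReporter()
# Two training iterations reported by the model registered as 'main'.
reporter.report({"loss": 0.5, "accuracy": 0.75}, prefix="main")
reporter.report({"loss": 0.25, "accuracy": 0.25}, prefix="main")
# An evaluator-style report puts 'validation' in front of 'main'.
reporter.report({"accuracy": 0.5}, prefix="validation/main")
summary = reporter.summarize()
```

Here `summary` maps 'main/loss' to 0.375, 'main/accuracy' to 0.5, and 'validation/main/accuracy' to 0.5, which is the shape of row a report printer could then display.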
Pseudo code of training loop abstraction
For each extension e:
    Invoke e if specified
Until stop_trigger is fired:
    Invoke updater
    For each extension e:
        If e’s trigger is fired:
            Invoke e
For each extension e:
    Finalize e
Finalize updater
• Trainer has a stop trigger that determines when to stop the training loop
• Each extension has a trigger that determines when it is invoked
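The pseudo code of the training loop can be turned into a small runnable sketch. Everything here is hypothetical plain Python that only mimics the structure (extensions optionally invoked before training, a stop trigger, per-extension triggers, then finalization); `ToyTrainer`, `IntervalTrigger`, `CountingUpdater`, and their methods are illustrative, not Chainer's API. For simplicity the triggers count iterations only, not epochs.

```python
class IntervalTrigger:
    """Fire when the iteration count is a positive multiple of `period`."""

    def __init__(self, period):
        self.period = period

    def __call__(self, trainer):
        return trainer.iteration > 0 and trainer.iteration % self.period == 0


class ToyTrainer:
    def __init__(self, updater, stop_trigger):
        self.updater = updater            # object with update() and finalize()
        self.stop_trigger = stop_trigger  # callable(trainer) -> bool
        self.extensions = []              # (extension, trigger, before) tuples
        self.iteration = 0

    def extend(self, extension, trigger, invoke_before_training=False):
        self.extensions.append((extension, trigger, invoke_before_training))

    def run(self):
        # For each extension e: invoke e if specified (before training).
        for ext, _, before in self.extensions:
            if before:
                ext(self)
        # Until stop_trigger fires: update, then invoke triggered extensions.
        while not self.stop_trigger(self):
            self.updater.update()
            self.iteration += 1
            for ext, trigger, _ in self.extensions:
                if trigger(self):
                    ext(self)
        # Finalize each extension that defines finalize(), then the updater.
        for ext, _, _ in self.extensions:
            if hasattr(ext, "finalize"):
                ext.finalize()
        self.updater.finalize()


class CountingUpdater:
    """Hypothetical updater that only counts how often it was invoked."""

    def __init__(self):
        self.updates = 0
        self.finalized = False

    def update(self):
        self.updates += 1

    def finalize(self):
        self.finalized = True


log = []
trainer = ToyTrainer(CountingUpdater(), stop_trigger=IntervalTrigger(10))
trainer.extend(lambda t: log.append(("before", t.iteration)),
               trigger=lambda t: False, invoke_before_training=True)
trainer.extend(lambda t: log.append(("every 3", t.iteration)),
               trigger=IntervalTrigger(3))
trainer.run()
```

After `run()`, the updater has been called 10 times and finalized, and `log` is `[("before", 0), ("every 3", 3), ("every 3", 6), ("every 3", 9)]`. Keeping triggers as separate callables is what lets the same extension run every epoch, every N iterations, or only once.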
Trainer-related modules
• Updater
  – Fetches a mini-batch using an Iterator and updates parameters using an Optimizer
  – You can customize the update routine
  – Built-in updaters: StandardUpdater, ParallelUpdater
• Extension
  – Adds an extra routine to the training loop
  – Basic extensions are built in: Evaluator, LogReport, PrintReport, ProgressBar, snapshot, snapshot_object, ExponentialDecay, LinearShift, dump_graph
  – You can write your own extensions
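As an illustration of writing your own extension, the sketch below is a plain-Python analogue in the spirit of ExponentialDecay: an extension is just a callable that receives the trainer, so it can reach an optimizer and rescale one of its hyperparameters each time its trigger fires. `ToyOptimizer`, `FakeTrainer`, `exponential_decay`, and the `trainer.optimizer` attribute are hypothetical stand-ins, not Chainer's implementation; in practice the built-in extension should be preferred.

```python
class ToyOptimizer:
    """Hypothetical optimizer holding only a learning rate."""

    def __init__(self, lr):
        self.lr = lr


def exponential_decay(attr, rate):
    """Return an extension that multiplies optimizer.<attr> by `rate`."""
    def extension(trainer):
        optimizer = trainer.optimizer
        setattr(optimizer, attr, getattr(optimizer, attr) * rate)
    return extension


class FakeTrainer:
    """Hypothetical stand-in exposing only what the extension needs."""

    def __init__(self, optimizer):
        self.optimizer = optimizer


trainer = FakeTrainer(ToyOptimizer(lr=0.1))
decay = exponential_decay("lr", 0.5)
for _ in range(3):  # pretend the extension's trigger fired three times
    decay(trainer)
```

After three invocations the learning rate is 0.1 × 0.5³ = 0.0125. Because the extension never decides *when* it runs, the same function could be registered with an epoch-based or iteration-based trigger without changes.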