GTC Japan 2016 Chainer feature introduction


Transcript of GTC Japan 2016 Chainer feature introduction

Chainer feature introduction: Training and dataset abstraction

5th Oct. 2016, GTC Japan @ Tokyo

Preferred Networks, Inc.
Kenta Oono

[email protected]

Trainer and Dataset abstraction

• New feature since v1.11.0
  ✓ Frees users from implementing training loops by themselves
  ✓ Supports most typical training procedures
  ✓ Easy to customize and extend
• Note: you can still write training loops manually without this feature, as in the examples of previous versions. A minimal sketch of such a loop follows.
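The sketch below assumes the MLP class defined on a later slide and uses only core Chainer APIs:

import chainer.links as L
from chainer import optimizers
from chainer.dataset import convert
from chainer.datasets import get_mnist
from chainer.iterators import SerialIterator

# Hand-written training loop: exactly what Trainer automates below.
train, _ = get_mnist()
train_iter = SerialIterator(train, batch_size=128)
model = L.Classifier(MLP())        # MLP is defined on a later slide
optimizer = optimizers.Adam()
optimizer.setup(model)

while train_iter.epoch < 10:       # run for 10 epochs
    batch = train_iter.next()
    x, t = convert.concat_examples(batch)
    optimizer.update(model, x, t)  # forward, backward, and update in one call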



Main modules

• Dataset, Iterator: extract mini-batches by iterating over datasets
• Trainer, Updater, Extension: customize the training loop at low cost
• Reporter: collects statistics from inside the models (see the sketch below)
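A minimal sketch of the Reporter mechanism, with a hypothetical Regressor model (L.Classifier, used later, reports in the same way): a link publishes values with chainer.report, and extensions such as LogReport collect them.

import chainer
import chainer.functions as F

class Regressor(chainer.Chain):
    # Hypothetical model that publishes its own loss via the Reporter.
    def __init__(self, predictor):
        super(Regressor, self).__init__(predictor=predictor)

    def __call__(self, x, t):
        y = self.predictor(x)
        loss = F.mean_squared_error(y, t)
        chainer.report({'loss': loss}, self)  # collected by LogReport etc.
        return loss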


[Diagram: a Trainer holds Extensions and an Updater; the Updater drives one or more Optimizers (each with a target Link) and Iterators (each reading a Dataset)]

We often use only one optimizer and one dataset; the diagram above shows the general case (a sketch follows).
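A hedged sketch of that general case: StandardUpdater (introduced on the next slides) also accepts dictionaries of iterators and optimizers, and a custom updater can override update_core to use several of them per iteration. TwoModelUpdater, the 'second' key, and optimizer_a/optimizer_b are illustrative names:

from chainer.training import StandardUpdater

class TwoModelUpdater(StandardUpdater):
    # Illustrative: trains two models on the same mini-batch.
    def update_core(self):
        batch = self.get_iterator('main').next()
        x, t = self.converter(batch, self.device)
        for name in ('main', 'second'):
            optimizer = self.get_optimizer(name)
            optimizer.update(optimizer.target, x, t)

# optimizer_a / optimizer_b: two optimizers already set up on their models.
updater = TwoModelUpdater(
    train_iter, {'main': optimizer_a, 'second': optimizer_b})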

MNIST classification by MLP with Trainer

class MLP(Chain):
    def __init__(self):
        super(MLP, self).__init__(
            l1=Linear(784, 1000),
            l2=Linear(1000, 1000),
            l3=Linear(1000, 10))

    def __call__(self, x):
        h1 = F.relu(self.l1(x))
        h2 = F.relu(self.l2(h1))
        return self.l3(h2)

[Diagram: x → Linear l1 (W, bias) → ReLU → Linear l2 (W, bias) → ReLU → Linear l3 (W, bias) → output]


# Prepare datasets and their iterators
train, test = get_mnist()
train_iter = SerialIterator(train, 128)
test_iter = SerialIterator(test, 128, repeat=False, shuffle=False)

# Prepare links and their optimizers
model = L.Classifier(MLP())
optimizer = Adam()
optimizer.setup(model)

# Prepare the trainer
updater = StandardUpdater(train_iter, optimizer)
trainer = Trainer(updater, (10, 'epoch'))


# Add extensions to augment the trainer
trainer.extend(Evaluator(test_iter, model))
trainer.extend(dump_graph('main/loss'))
trainer.extend(snapshot())
trainer.extend(LogReport())
trainer.extend(PrintReport(['epoch', 'main/accuracy',
                            'validation/main/accuracy']))
trainer.extend(ProgressBar())

# Execute
trainer.run()


Pseudo-code of the training loop abstraction

for each extension e:
    invoke e before training if specified

until stop_trigger is fired:
    invoke updater
    for each extension e:
        if e's trigger is fired:
            invoke e

for each extension e:
    finalize e
finalize updater
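In Python terms, a simplified sketch (not the actual Trainer.run source, which also manages the Reporter, error handling, and a private extension registry):

# Sketch: `extensions` is a list of entries with .extension (a callable
# taking the trainer), .trigger (a callable returning bool), and
# .invoke_before_training (a flag), mirroring the pseudo-code above.
def run(trainer, updater, extensions, stop_trigger):
    for e in extensions:
        if e.invoke_before_training:
            e.extension(trainer)

    while not stop_trigger(trainer):
        updater.update()              # one iteration of training
        for e in extensions:
            if e.trigger(trainer):
                e.extension(trainer)

    for e in extensions:
        e.extension.finalize()
    updater.finalize()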


• Trainer has a stop trigger that determines when to stop the training loop

• Each extension has its own trigger that determines when it is invoked (see the example below)
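For example, with interval triggers (the snapshot interval here is just illustrative):

# Stop training after 10 epochs.
trainer = Trainer(updater, stop_trigger=(10, 'epoch'))

# Run the evaluator every epoch (also Evaluator's default trigger).
trainer.extend(Evaluator(test_iter, model), trigger=(1, 'epoch'))

# Take a snapshot every 1000 iterations instead of every epoch.
trainer.extend(snapshot(), trigger=(1000, 'iteration'))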


Trainer-related modules

• Updater
  – Fetches a mini-batch using an Iterator and updates parameters using an Optimizer
  – You can customize the update routine
  – Built-in updaters: StandardUpdater, ParallelUpdater
• Extension
  – Adds an extra routine to the training loop
  – Basic extensions are built-in: Evaluator, LogReport, PrintReport, ProgressBar, snapshot, snapshot_object, ExponentialDecay, LinearShift, dump_graph
  – You can write your own extensions (a sketch follows)
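A minimal sketch of a user-defined extension; report_progress is a hypothetical name, while make_extension and the trigger syntax are Chainer's:

from chainer import training

# Hypothetical extension: runs once per epoch and receives the trainer.
@training.make_extension(trigger=(1, 'epoch'))
def report_progress(trainer):
    print('finished epoch', trainer.updater.epoch)

trainer.extend(report_progress)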


Dataset-related modules

• Dataset is just a sequence of data points (a.k.a. examples; see the sketch below)
• Iterator defines how to iterate over the dataset
• Built-in iterators:
  – SerialIterator
  – MultiprocessIterator
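For instance, a minimal sketch (the random arrays stand in for real data):

import numpy as np
from chainer.datasets import TupleDataset
from chainer.iterators import SerialIterator

# Any sequence works as a dataset; TupleDataset zips arrays
# into (input, label) example tuples.
x = np.random.rand(100, 784).astype(np.float32)
t = np.random.randint(0, 10, size=100).astype(np.int32)
dataset = TupleDataset(x, t)

it = SerialIterator(dataset, batch_size=32)
batch = it.next()   # a list of 32 (input, label) tuples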
