chainer入门

本文介绍了Chainer的神经网络结构,包括dataset、iterator、optimizer和model。重点讲解了Updater的作用以及如何创建Trainer对象。核心部分是Extensions模块,它允许在训练过程中进行操作,如评估模型、记录进程和保存模型。通过添加Extension到Trainer并配置trigger,可以灵活控制训练流程。同时提到了@make_extension装饰器的使用。
摘要由CSDN通过智能技术生成

chainer是比较早期的神经网络结构,2015年发布。目前用pytorch的还是比较多,但是chainer实现的辅助工具依然是非常实用和受欢迎的。

一、chainer结构

dataset,iterator,optimizer and model是大家所熟知的,就不细讲了 

Updater:

Now that we have the training iterator and optimizer set up, we link them both together into the updater. The updater uses the minibatches from the iterator, does the forward and backward processing of the model, and updates the parameters of the model according to the optimizer.

49# Create the updater, using the optimizer
50updater = training.StandardUpdater(train_iter, optimizer, device=-1)

Finally we create a Trainer object. The trainer processes minibatches using the updater。

52# Set up a trainer
53trainer = training.Trainer(updater, (50, 'epoch'), out='result')

Extensions:

(这是最主要的部分)

这个模块可以在训练的中途进行操作,比如训练多少iterations或者多少epoches之后,甚至是steps之后。主要是用来评估模型,记录模型进程,或者用来保存模型的。

功能如下:

First, use the testing iterator defined above for an Evaluator extension to the trainer to provide test scores.

#54 Evaluate the model with the test dataset for each epoch
trainer.extend(extensions.Evaluator(test_iter, model, device=-1))

Save a computational graph from loss variable at the first iteration:

#57 Dump a computational graph from 'loss' variable at the first iteration
# The "main" refers to the target link of the "main" optimizer.
trainer.extend(extensions.DumpGraph('main/loss'))

Take a snapshot of the trainer object every 20 epochs.

trainer.extend(extensions.snapshot(), trigger=(20, 'epoch'))

Write a log of evaluation statistics for each epoch.

# 63 Write a log of evaluation statistics for each epoch
trainer.extend(extensions.LogReport())

Save two plot images to the result directory.

# Save two plot images to the result dir
trainer.extend(
    extensions.PlotReport(['main/loss', 'validation/main/loss'],
                          'epoch', file_name='loss.png'))
trainer.extend(
    extensions.PlotReport(
        ['main/accuracy', 'validation/main/accuracy'],
        'epoch', file_name='accuracy.png'))

Print selected entries of the log to standard output.

# Print selected entries of the log to stdout
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/loss', 'validation/main/loss',
     'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

Main Loop:

#  Run the training
trainer.run()

What is trainer Extension?

By adding an Extension to a Trainer using the extend() method, the Extension will be called according to the schedule specified by using a trigger object.

The Trainer object contains all information used in a training loop, e.g., models, optimizers, updaters, iterators, and datasets, etc. This makes it possible to change settings such as the learning rate of an optimizer.

用好trigger就可以了

trainer.extend(lr_drop, trigger=(10, 'epoch'))

使用装饰器@make_extension

@training.make_extension(trigger=(10, 'epoch'))
def lr_drop(trainer):
    trainer.updater.get_optimizer('main').lr *= 0.1

详细请看chainer官网

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值