pytorch中ignite的学习笔记

ignite 0.2.0 模块

ignite 是pytorch中训练模型的高级库
参考网址
https://blog.csdn.net/qq_29257201/article/details/94454657
https://pytorch.org/ignite/quickstart.html

概念介绍

在ignite中主要的是三个概念:Engine,Events and Handles,States

  1. Engine: 是ignite库中最重要的类,它主要是提供了trainer和evaluater;
  2. Events and Handles:
    为了提高engine的灵活性引入Events and Handles系统促进run中每个步骤的交互性:
  • engine的开始和结束
  • epoch 的开始和结束
  • batch iteration的开始和结束
  1. States:存储process_function,current epoch,iteration和其他有用信息的输出
  • engine.state.epoch: epochs the engine已完成的次数 . 初始化为 0.
  • engine.state.max_epochs: epochs要运行的次数. 初始化为 1.
  • engine.state.iteration: 已完成的迭代次数. 初始化为 0.
  • engine.state.output: the process_function defined for the Engine的输出.

训练框架

  • 构建class Net(nn.Module)网络结构
  • 定义函数get_data_loaders(train_batch_size, val_batch_size), 返回train_loader,val_loader
  • 创建函数create_summary_writer(model,data_loader, dog_dir) ,返回的是一个SummaryWriter对象
  • 定义run(train_batch_size,val_batch_size, epochs, lr, momentum, log_interval, log_dir)
  • 运行函数

例子与解析

完整代码:

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator

model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
loss = torch.nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model,
                                        metrics={
                                            'accuracy': Accuracy(),
                                            'nll': Loss(loss)
                                            })

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(trainer):
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics['accuracy'], metrics['nll']))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics['accuracy'], metrics['nll']))

trainer.run(train_loader, max_epochs=100)

在前4行中,我们定义了模型,训练和验证数据集(如torch.utils.data.DataLoader),优化器和损失函数:

model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
loss = torch.nn.NLLLoss()

接下来我们定义trainer and evaluator engines。Ignite的主要组成部分是对Engine训练循环的抽象。开始使用引擎很简单,构造函数只需要一件事:

  • update_function: a function that receives the engine and a batch and have a role to update your model.

在上面的例子中,我们使用辅助方法create_supervised_trainer()create_supervised_evaluator()

trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model,
                                        metrics={
                                            'accuracy': Accuracy(),
                                            'nll': Loss(loss)
                                            })

Engine允许在运行期间触发的各种事件上添加处理程序。触发事件时,将执行附加的处理程序(函数)。因此,为了记录目的,我们添加了一个在每次迭代后执行的函数:

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    print("Epoch[{}] Loss: {:.2f}".format(engine.state.epoch, len(train_loader), engine.state.output))

当一个epoch结束时,我们需要计算training和validation指标。为此,我们可以在train_loader和val_loader上运行先前定义的求值程序。因此,我们在epoch完成事件上附加了两个额外的处理程序:

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics['accuracy'], metrics['nll']))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(engine.state.epoch, metrics['accuracy'], metrics['nll']))

最后,我们在训练数据集上启动engine并在100个epochs内运行它:

trainer.run(train_loader, max_epochs=100)
  • 1
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值