pytorch之Bootstrap简单介绍

根据 bootstrap.pytorch官方翻译的

1 简介

Bootstrap是启动深度学习项目的高级框架。它旨在通过提供只关注数据集和模型的强大工作流来加速研究项目和原型开发。

1.1 下载

pip install bootstrap.pytorch

2 内容

在这里插入图片描述

bootstrap包含Engine(启动引擎),Dataset(数据集),Model(模型),Options(选择),Logger(日志),View(评估可视化)模块,具体内容如下:
在这里插入图片描述

2.1 Engine

Boostrap的核心是bootstrap.engines.Enginine类,它包含在所提供的数据集、模型、优化器和视图上循环给定次数的训练和评估方法
Engine包含可以在训练和评估过程中随时触发的hook调用。每个engine components都能够注册hook函数。

engine.register_hook('train_on_update', display_learning_rate)

此外,在Engine的初始化过程中,我们严重依赖工factory提供最佳的自定义:

engine = engines.factory()

# dataset是一个字典,它包含所有需要的按模式索引的数据集 {train, eval}
engine.dataset = datasets.factory(engine)

# model包括网络network、标准criterion和度量metric
engine.model = models.factory(engine)

engine.optimizer = optimizers.factory(engine.model, engine)
# view将保存一个view.html在实验目录
# view:用一些漂亮的图形和曲线来监控训练
engine.view = views.factory(engine)

在这里插入图片描述

2.1.1 eval:

eval():启动评估程序,

eval_epoch(model, dataset, epoch, mode='eval', logs_json=True):启动一个时期的评估程序
hooks列表(默认mode=‘eval’):

  • mode_on_start_epoch:在一个epoch的计算过程之前
  • mode_on_start_batch:在一个batch的计算过程之前
  • mode_on_forward:在模型的forward之后
  • mode_on_print:打印到终端后
  • mode_on_end_batch:批处理的评估过程结束
  • mode_on_end_epoch:在log .json中保存日志之前
  • mode_on_flush: 一个时期评估的计算过程结束
    Returns:模型的所有标量输出的平均值,按输出名进行索引
    return type:out(dict)

generate_view():通过对self.view.generate()的异步调用生成一个view.html

is_best(out,saving_criteria)::验证最后一个模型是否最适合特定的保存条件
参数:out(dict):根据输出名称索引的模型的所有标量输出的平均值,saving_criteria(str)
案例:
out = {
‘loss’: 0.2,
‘acctop1’: 87.02
}
engine.is_best(out, ‘loss:min’)

load(dir_logs, name, model, optimizer):加载一个检查点

resume() :使用bootstrap.lib.options.Options恢复检查点
save(dir_logs, name, model, optimizer):保存一个检查点

train():开始训练一个程序
hooks列表:

  • train_on_start:在完整的培训程序之前

train_epoch(model, dataset, optimizer, epoch, mode='train'):启动一个时期的训练程序

bootstrap.engines.logger.LoggerEngine类:LoggerEngine与Engine类似。唯一的区别是一个更强大的is_best方法。它能够查看包含按名称索引的所有日志变量的列表的记录器字典。

2.2 Dataset

在bootstrap中,dataset将与数据加载器融合,数据集的唯一要求是类型为torch.utils.data.Dataset,并具有引擎调用的make_batch_loader方法

batch_loader = dataset.make_batch_loader()
for i, batch in enumerate(batch_loader):
     # training or evaluation(循环执行训练或者评估)

2.3 Model

模型是一个torch.nn.Module由一个网络(例如:resnet 152)组成,在前向传播过程中,模型model负责将batch传递到networkcriterionmetric。它输出一个损失值loss

out = model(batch)
out['loss'].backward()

为了灵活和支持复杂的工作流,损失函数criterions和度量metric在训练过程中可能会根据执行模式(train或eval)或数据集分割(train、val、test)而有所不同。

2.3.1 bootstrap.models.model.Model类

Model包含一个网络network、两个标准criterions(train、eval)和两个度量metrics。

eval():激活评估模式

forward(batch): 准备batch并将其输入网络、标准和度量。
return:输出一个字典

prepare_batch(batch):准备一个包含两个函数的batch:cuda_tfdetach_tf(仅在eval模式下)

train():激活训练模式

2.3.2 bootstrap.models.model.DefaultModel类

依赖于工厂调用的模型扩展

2.3.3 bootstrap.models.model.SimpleModel

修改forward函数的DefaultModel扩展

forward(batch):对网络的转发调用使用batch[’ data ']而不是batch

2.4 Options

options类是boostrap程序的核心组件之一。它管理实验的所有(超)参数,并将它们存储在一个yaml文件中,通过解析该文件来创建默认的命令行参数。
options可以很容易地被命令行参数覆盖,从而方便超参数搜索。

python –m bootstrap.run -o mnist/options/sgd.yaml #执行这个yaml文件设置参数目录
        --exp.dir logs/example #结果日志输出到这里个文件
        --model.metric.topk 1 2 3 #sgd.yaml文件中的metric.topk的值被修改为1 2 3

options的格式如下:

exp:
  dir: logs/mnist
  resume:
dataset:
  import: mnist.datasets.factory
  name: mnist
  dir: data/mnist
  train_split: train
  eval_split: val
  nb_threads: 4
  batch_size: 64
model:
  name: simple
    network:
      import: mnist.models.networks.factory
      name: lenet
    criterion:
      name: nll
    metric:
      name: accuracy
      topk: [1,5]
optimizer:
  name: sgd
  lr: 0.01
engine:
  name: default
  nb_epochs: 10
  saving_criteria:
    - loss:min
    - acctop1:max
view:
  - logs:train_epoch.loss
  - logs:eval_epoch.acctop1

您可以使用singleton bootstrap.lib.options.Options类访问每个参数。

opt = Options()
print(opt['engine'])      # 字典类型
print(opt['engine.name']) # 字符串类型

一旦加载并可能被覆盖,这些选项将作为YAML文件存储在实验目录中,从而使实验易于复制和恢复。

python –m bootstrap.run
    -o logs/mnist/options.yaml
    --exp.resume best_acctop1

2.5 Logger

与实验相关的所有内容都存储在同一个目录中(日志logs、检查点checkpoints、可视化visualizations、选项options等)。

ls logs/mnist
  ckpt_last_engine.pth.tar
  ckpt_last_model.pth.tar
  ckpt_last_optimizer.pth.tar
  ckpt_best_acctop1_engine.pth.tar
  ckpt_best_acctop1_model.pth.tar
  ckpt_best_acctop1_optimizer.pth.tar
  logs.json
  logs.txt
  options.yaml
  view.html

singleton bootstrap.lib.logger.Logger类可以使用简单的键值接口记录任何变量。

Logger(dir_logs='logs/mnist')
Logger().log_value('train_epoch.epoch', epoch)
Logger().log_value('train_epoch.mean_acctop1', mean_acctop1)
Logger().flush() # write the logs.json

Logger类是一个单例。它包含用于在键值字典中记录变量的所有实用程序。它也可以被看作是打印函数的替代。

Logger(dir_logs='logs/mnist')
Logger().log_value('train_epoch.epoch', epoch)
Logger().log_value('train_epoch.mean_acctop1', mean_acctop1)
Logger().flush() # write the logs.json

Logger()("Launching training procedures") # written to logs.txt
> [I 2018-07-23 18:58:31] ...trap/engines/engine.py.80: Launching training procedures

2.6 View

每个训练和评估阶段结束时,bootstrap.views.view。视图类从logs.json加载数据,并生成可视化。默认情况下,bootstrap.pytorch依靠plotly库(由vizdom使用)在view.html文件中使用javascript创建动态绘图,但也可以生成tensorboard文件。

  • 3
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值