PyTorch-Lightning
- 介绍
- 安装
- 实用功能 (Trainer参数详解)
-
- 自动获取Batch Size - Automatic Batch Size Finder
- 自动获取初始学习率 - Automatic Learning Rate Finder
- 重新加载数据 - Reload DataLoaders Every Epoch
- 回调函数 - Callbacks
- 展示网络信息 - Weights Summary
- 进度条 - Progress Bar
- 训练以及测试循环 - Training and Eval Loops
- 单GPU以及多GPUs训练 - Training on GPUs
- 进阶分布式训练 - Advanced distributed training
- 代码测试 - Debugging
- 梯度累积 - Accumulating Gradients
- 混合精度训练 - Mixed Precision Training
- 示例代码
介绍
来源:http://blog.itpub.net/31555081/viewspace-2698296/
PyTorch很容易使用,可以用来构建复杂的AI模型。但是一旦研究变得复杂,并且将诸如多GPU训练,16位精度和TPU训练之类的东西混在一起,用户很可能会写出有bug的代码。
PyTorch Lightning完全解决了这个问题。Lightning会构建您的PyTorch代码,以便抽象出训练的详细信息。这使得AI研究可扩展并且可以快速迭代。
PyTorch Lightning是NYU和FAIR(Facebook AI research)为从事AI研究的专业研究人员和博士生所创建的。
安装
conda activate your-env
pip install pytorch-lightning
实用功能 (Trainer参数详解)
来源:Lightning官方教程 和 官方文档
自动获取Batch Size - Automatic Batch Size Finder
auto_scale_batch_size
Batch Size一般会对模型的训练结果有影响,一般越大的batch size模型训练的结果会越好。有时候我们不知道自己的模型在当前的机器上最多能用多大的batch size,这时候通过Lightning Trainer的这个flag就可以帮助我们找到最大的batch size。
model = ...
# 设置为True,Trainer就会依次尝试用2的幂次方的batch size,直到超出内存
trainer = pl.Trainer(auto_scale_batch_size=True)
trainer.fit(model)
# 设置为'binsearch',Trainer会用Binary Search的方式帮你找到最大的Batch Size
trainer = pl.Trainer(auto_scale_batch_size='binsearch')
trainer.tune(model)
# 注意:如果要用这个功能,在Module里面的__init__()函数中要有:
self.batch_size = batch_size
# 或者在__init__()里面调用:
self.save_hyperparameters()
自动获取初始学习率 - Automatic Learning Rate Finder
auto_lr_find
学习率learning rate是很重要的一个超参,选取一个合适的初始学习率也是很重要的,Lightning提供了这个有用的flag。
(15.11.2020)目前只支持单优化器(Optimizer),预计在未来的几个月内会支持多优化器
import pytorch_lightning as pl
model = ...
# 可以直接设置为True,Trainer会自动用不同的学习率运行model,然后画出loss和学习率的曲线,帮你找到最合适的学习率
trainer = pl.Trainer(auto_lr_find=True)
trainer.tune(model)
print(model.learning_rate)
# 有时候我们会在model中给学习率起其他的名字,比如:
self.my_learning_rate = lr
# 这个时候我们可以用变量名直接设置auto_lr_find:
trainer = pl.Trainer(auto_lr_find='my_learing_rate')
# 开始寻找合适的学习率
lr_finder = trainer.tuner.lr_find(model)
# 展示loss和学习率的曲线
fig = lr_finder.plot(suggest=True)
fig.show()
# 设置为推荐的学习率
model.hyparams.learning_rate = lr_finder.suggestion()
# 开始训练
model.fit(model, train_loader, val_loader)
重新加载数据 - Reload DataLoaders Every Epoch
reload_dataloaders_every_epoch
一般数据只会在一开始加载一次,即在epochs的循环前面加载一次,然后每个循环都会shuffle之类的(如果你设置shuffle为True的话)。有时候我们的数据在训练过程中是会改变的,这个时候我们就需要在每个epoch都要再加载一次数据,Lightning就提供了这样一个flag,将其设置为True即可。
# 相当于:
# if False (default)
train_loader = model.train_dataloader()
for epoch in epochs:
for batch in train_loader:
...
# if True
for epoch in epochs:
train_loader = model.train_dataloader()
for batch in train_loader:
回调函数 - Callbacks
callbacks
回调函数 (Callbacks) 在机器学习中也是很重要的工具,一般可以用来进行模型的断点断续,模型权重的存储,提早停止 (Early stop),动态调整训练参数以及tensorboard之类的训练可视化等等。Lightning也支持非常灵活的Callbacks,只需要把Callbacks放进flag:callbacks中即可。Lightning提供了一些built-in的callbacks,同样也支持自定义callbacks类,所以非常灵活。
官方的callbacks说明
# 自定义Callback类
from pytorch_lightning.callbacks import Callback
class MyPrintingCallback(Callback):
def on_init_start(self, trainer):
print('Starting to init trainer!')
def on_init_end(self, trainer):
print('trainer is init now')
def on_train_end(self, trainer, pl_module):
print('do something when training ends')
trainer = Trainer(callbacks=[MyPrintingCallback()])
# 使用built-in的Callbacks
from pytorch_lightning.callbacks import EarlyStopping
# 可以直接使用默认的Callbacks
trainer = pl.Trainer(callbacks