【GiantPandaCV导语】Pytorch Lightning是在Pytorch基础上进行封装的库,为了让用户能够脱离PyTorch一些繁琐的细节,专注于核心代码的构建,提供了许多实用工具,可以让实验更加高效。本文将介绍安装方法、设计逻辑、转化的例子等内容。
PyTorch Lightning中提供了以下比较方便的功能:
- multi-GPU训练
- 半精度训练
- TPU 训练
- 将训练细节进行抽象,从而可以快速迭代
1. 简单介绍
PyTorch lightning 是为AI相关的专业的研究人员、研究生、博士等人群开发的。PyTorch就是William Falcon在他的博士阶段创建的,目标是让AI研究扩展性更强,忽略一些耗费时间的细节。
目前PyTorch Lightning库已经有了一定的影响力,star已经1w+,同时有超过1千多的研究人员在一起维护这个框架。
同时PyTorch Lightning也在随着PyTorch版本的更新也在不停迭代。
官方文档也有支持,正在不断更新:
下面介绍一下如何安装。
2. 安装方法
Pytorch Lightning安装非常方便,推荐使用conda环境进行安装。
source activate you_env
pip install pytorch-lightning
或者直接用pip安装:
pip install pytorch-lightning
或者通过conda安装:
conda install pytorch-lightning -c conda-forge
3. Lightning的设计思想
Lightning将大部分AI相关代码分为三个部分:
-
研究代码,主要是模型的结构、训练等部分。被抽象为LightningModule类。
-
工程代码,这部分代码重复性强,比如16位精度,分布式训练。被抽象为Trainer类。
-
非必要代码,这部分代码和实验没有直接关系,不加也可以,加上可以辅助,比如梯度检查,log输出等。被抽象为Callbacks类。
Lightning将研究代码划分为以下几个组件:
- 模型
- 数据处理
- 损失函数
- 优化器
以上四个组件都将集成到LightningModule类中,是在Module类之上进行了扩展,进行了功能性补充,比如原来优化器使用在main函数中,是一种面向过程的用法,现在集成到LightningModule中,作为一个类的方法。
4. LightningModule生命周期
这部分参考了https://zhuanlan.zhihu.com/p/120331610 和 官方文档 https://pytorch-lightning.readthedocs.io/en/latest/trainer.html
在这个模块中,将PyTorch代码按照五个部分进行组织:
- Computations(init) 初始化相关计算
- Train Loop(training_step) 每个step中执行的代码
- Validation Loop(validation_step) 在一个epoch训练完以后执行Valid
- Test Loop(test_step) 在整个训练完成以后执行Test
- Optimizer(configure_optimizers) 配置优化器等
展示一个最简代码:
>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
... def __init__(self):
... super().__init__()
... self.l1 = torch.nn.Linear(28 * 28, 10)
...
... def forward(self, x):
... return torch.relu(self.l1(x.view(x.size(0), -1)))
...
... def training_step(self, batch, batch_idx):
... x, y = batch
... y_hat = self(x)
... loss = F.cross_entropy(y_hat, y)
... return loss
...
... def configure_optimizers(self):
... return torch.optim.Adam(self.parameters(), lr=0.02)
那么整个生命周期流程是如何组织的?
4.1 准备工作
这部分包括LightningModule的初始化、准备数据、配置优化器。每次只执行一次,相当于构造函数的作用。
__init__()
(初始化 LightningModule )prepare_data()
(准备数据,包括下载数据、预处理等等)configure_optimizers()
(配置优化器)
4.2 测试 验证部分
实际运行代码前,会随即初始化模型,然后运行一次验证代码,这样可以防止在你训练了几个epoch之后要进行Valid的时候发现验证部分出错。主要测试下面几个函数:
val_dataloader()
validation_step()
validation_epoch_end()
4.3 加载数据
调用以下方法进行加载数据。
train_dataloader()
val_dataloader()
4.4 训练
-
每个batch的训练被称为一个step,故先运行train_step函数。
-
当经过多个batch, 默认49个step的训练后,会进行验证,运行validation_step函数。
-
当完成一个epoch的训练以后,会对整个epoch结果进行验证,运行validation_epoch_end函数
-
(option)如果需要的话,可以调用测试部分代码:
- test_dataloader()
- test_step()
- test_epoch_end()
5. 示例
以MNIST为例,将PyTorch版本代码转为PyTorch Lightning。
5.1 PyTorch版本训练MNIST
对于一个PyTorch的代码来说,一般是这样构建网络(源码来自PyTorch中的example库)。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32,