PyTorch Lightning工具学习

【GiantPandaCV导语】Pytorch Lightning是在Pytorch基础上进行封装的库,为了让用户能够脱离PyTorch一些繁琐的细节,专注于核心代码的构建,提供了许多实用工具,可以让实验更加高效。本文将介绍安装方法、设计逻辑、转化的例子等内容。

PyTorch Lightning中提供了以下比较方便的功能:

  • multi-GPU训练
  • 半精度训练
  • TPU 训练
  • 将训练细节进行抽象,从而可以快速迭代

Pytorch Lightning

1. 简单介绍

PyTorch lightning 是为AI相关的专业的研究人员、研究生、博士等人群开发的。PyTorch就是William Falcon在他的博士阶段创建的,目标是让AI研究扩展性更强,忽略一些耗费时间的细节。

目前PyTorch Lightning库已经有了一定的影响力,star已经1w+,同时有超过1千多的研究人员在一起维护这个框架。

PyTorch Lightning库

同时PyTorch Lightning也在随着PyTorch版本的更新也在不停迭代。

版本支持情况

官方文档也有支持,正在不断更新:

官方文档

下面介绍一下如何安装。

2. 安装方法

Pytorch Lightning安装非常方便,推荐使用conda环境进行安装。

source activate you_env
pip install pytorch-lightning

或者直接用pip安装:

pip install pytorch-lightning

或者通过conda安装:

conda install pytorch-lightning -c conda-forge

3. Lightning的设计思想

Lightning将大部分AI相关代码分为三个部分:

  • 研究代码,主要是模型的结构、训练等部分。被抽象为LightningModule类。

  • 工程代码,这部分代码重复性强,比如16位精度,分布式训练。被抽象为Trainer类。

  • 非必要代码,这部分代码和实验没有直接关系,不加也可以,加上可以辅助,比如梯度检查,log输出等。被抽象为Callbacks类。

Lightning将研究代码划分为以下几个组件:

  • 模型
  • 数据处理
  • 损失函数
  • 优化器

以上四个组件都将集成到LightningModule类中,是在Module类之上进行了扩展,进行了功能性补充,比如原来优化器使用在main函数中,是一种面向过程的用法,现在集成到LightningModule中,作为一个类的方法。

4. LightningModule生命周期

这部分参考了https://zhuanlan.zhihu.com/p/120331610 和 官方文档 https://pytorch-lightning.readthedocs.io/en/latest/trainer.html

在这个模块中,将PyTorch代码按照五个部分进行组织:

  • Computations(init) 初始化相关计算
  • Train Loop(training_step) 每个step中执行的代码
  • Validation Loop(validation_step) 在一个epoch训练完以后执行Valid
  • Test Loop(test_step) 在整个训练完成以后执行Test
  • Optimizer(configure_optimizers) 配置优化器等

展示一个最简代码:

>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(28 * 28, 10)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
...     def training_step(self, batch, batch_idx):
...         x, y = batch
...         y_hat = self(x)
...         loss = F.cross_entropy(y_hat, y)
...         return loss
...
...     def configure_optimizers(self):
...         return torch.optim.Adam(self.parameters(), lr=0.02)

那么整个生命周期流程是如何组织的?

4.1 准备工作

这部分包括LightningModule的初始化、准备数据、配置优化器。每次只执行一次,相当于构造函数的作用。

  • __init__()(初始化 LightningModule )
  • prepare_data() (准备数据,包括下载数据、预处理等等)
  • configure_optimizers() (配置优化器)

4.2 测试 验证部分

实际运行代码前,会随即初始化模型,然后运行一次验证代码,这样可以防止在你训练了几个epoch之后要进行Valid的时候发现验证部分出错。主要测试下面几个函数:

  • val_dataloader()
  • validation_step()
  • validation_epoch_end()

4.3 加载数据

调用以下方法进行加载数据。

  • train_dataloader()
  • val_dataloader()

4.4 训练

  • 每个batch的训练被称为一个step,故先运行train_step函数。

  • 当经过多个batch, 默认49个step的训练后,会进行验证,运行validation_step函数。

  • 当完成一个epoch的训练以后,会对整个epoch结果进行验证,运行validation_epoch_end函数

  • (option)如果需要的话,可以调用测试部分代码:

    • test_dataloader()
    • test_step()
    • test_epoch_end()

5. 示例

以MNIST为例,将PyTorch版本代码转为PyTorch Lightning。

5.1 PyTorch版本训练MNIST

对于一个PyTorch的代码来说,一般是这样构建网络(源码来自PyTorch中的example库)。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

*pprp*

如果有帮助可以打赏一杯咖啡

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值