PyTorch Lightning工具学习

【GiantPandaCV导语】Pytorch Lightning是在Pytorch基础上进行封装的库,为了让用户能够脱离PyTorch一些繁琐的细节,专注于核心代码的构建,提供了许多实用工具,可以让实验更加高效。本文将介绍安装方法、设计逻辑、转化的例子等内容。

PyTorch Lightning中提供了以下比较方便的功能:

  • multi-GPU训练
  • 半精度训练
  • TPU 训练
  • 将训练细节进行抽象,从而可以快速迭代

Pytorch Lightning

1. 简单介绍

PyTorch lightning 是为AI相关的专业的研究人员、研究生、博士等人群开发的。PyTorch就是William Falcon在他的博士阶段创建的,目标是让AI研究扩展性更强,忽略一些耗费时间的细节。

目前PyTorch Lightning库已经有了一定的影响力,star已经1w+,同时有超过1千多的研究人员在一起维护这个框架。

PyTorch Lightning库

同时PyTorch Lightning也在随着PyTorch版本的更新也在不停迭代。

版本支持情况

官方文档也有支持,正在不断更新:

官方文档

下面介绍一下如何安装。

2. 安装方法

Pytorch Lightning安装非常方便,推荐使用conda环境进行安装。

source activate you_env
pip install pytorch-lightning

或者直接用pip安装:

pip install pytorch-lightning

或者通过conda安装:

conda install pytorch-lightning -c conda-forge

3. Lightning的设计思想

Lightning将大部分AI相关代码分为三个部分:

  • 研究代码,主要是模型的结构、训练等部分。被抽象为LightningModule类。

  • 工程代码,这部分代码重复性强,比如16位精度,分布式训练。被抽象为Trainer类。

  • 非必要代码,这部分代码和实验没有直接关系,不加也可以,加上可以辅助,比如梯度检查,log输出等。被抽象为Callbacks类。

Lightning将研究代码划分为以下几个组件:

  • 模型
  • 数据处理
  • 损失函数
  • 优化器

以上四个组件都将集成到LightningModule类中,是在Module类之上进行了扩展,进行了功能性补充,比如原来优化器使用在main函数中,是一种面向过程的用法,现在集成到LightningModule中,作为一个类的方法。

4. LightningModule生命周期

这部分参考了https://zhuanlan.zhihu.com/p/120331610 和 官方文档 https://pytorch-lightning.readthedocs.io/en/latest/trainer.html

在这个模块中,将PyTorch代码按照五个部分进行组织:

  • Computations(init) 初始化相关计算
  • Train Loop(training_step) 每个step中执行的代码
  • Validation Loop(validation_step) 在一个epoch训练完以后执行Valid
  • Test Loop(test_step) 在整个训练完成以后执行Test
  • Optimizer(configure_optimizers) 配置优化器等

展示一个最简代码:

>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
...     def __init__(self):
...         super().__init__()
...         self.l1 = torch.nn.Linear(28 * 28, 10)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
...     def training_step(self, batch, batch_idx):
...         x, y = batch
...         y_hat = self(x)
...         loss = F.cross_entropy(y_hat, y)
...         return loss
...
...     def configure_optimizers(self):
...         return torch.optim.Adam(self.parameters(), lr=0.02)

那么整个生命周期流程是如何组织的?

4.1 准备工作

这部分包括LightningModule的初始化、准备数据、配置优化器。每次只执行一次,相当于构造函数的作用。

  • __init__()(初始化 LightningModule )
  • prepare_data() (准备数据,包括下载数据、预处理等等)
  • configure_optimizers() (配置优化器)

4.2 测试 验证部分

实际运行代码前,会随即初始化模型,然后运行一次验证代码,这样可以防止在你训练了几个epoch之后要进行Valid的时候发现验证部分出错。主要测试下面几个函数:

  • val_dataloader()
  • validation_step()
  • validation_epoch_end()

4.3 加载数据

调用以下方法进行加载数据。

  • train_dataloader()
  • val_dataloader()

4.4 训练

  • 每个batch的训练被称为一个step,故先运行train_step函数。

  • 当经过多个batch, 默认49个step的训练后,会进行验证,运行validation_step函数。

  • 当完成一个epoch的训练以后,会对整个epoch结果进行验证,运行validation_epoch_end函数

  • (option)如果需要的话,可以调用测试部分代码:

    • test_dataloader()
    • test_step()
    • test_epoch_end()

5. 示例

以MNIST为例,将PyTorch版本代码转为PyTorch Lightning。

5.1 PyTorch版本训练MNIST

对于一个PyTorch的代码来说,一般是这样构建网络(源码来自PyTorch中的example库)。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 
PyTorch Lightning是一种轻量级的高级PyTorch封装,它使得训练神经网络更加容易、更加模块化。它提供了许多常用的功能,例如自动分布式训练、自动检查点、自动日志记录等等。下面是一个PyTorch Lightning学习指南: 1. 先学习PyTorch基础知识:在学习PyTorch Lightning之前,您需要先学习PyTorch的基础知识,例如如何构建神经网络、如何训练模型等等。 2. 安装PyTorch Lightning:在安装PyTorch Lightning之前,您需要先安装PyTorch。然后可以通过pip安装PyTorch Lightning。 3. 了解PyTorch Lightning的核心概念:PyTorch Lightning的核心概念是“LightningModule”、“Trainer”和“DataModule”。LightningModule是您定义神经网络的地方,Trainer是您定义训练过程的地方,DataModule是您定义数据集的地方。 4. 编写您的第一个PyTorch Lightning程序:您可以从一个简单的例子开始,例如MNIST手写数字识别。在这个例子中,您可以定义一个LightningModule来构建神经网络,定义一个DataModule来加载数据集,然后定义一个Trainer来训练模型。 5. 学习如何自动分布式训练:PyTorch Lightning可以自动进行分布式训练,这意味着您可以在多个GPU或多台计算机上训练模型。您只需要在Trainer中设置一些参数即可。 6. 学习如何自动检查点和日志记录:PyTorch Lightning可以自动保存检查点和记录日志,这使得您可以在训练过程中随时恢复模型并查看训练指标。 7. 学习如何使用PyTorch Lightning扩展您的研究:PyTorch Lightning提供了许多扩展功能,例如自动优化器、自动批量大小调整、自动对抗性训练等等。您可以使用这些功能来扩展您的研究。 总之,PyTorch Lightning是一个非常强大的工具,可以使训练神经网络更加容易和高效。如果您想提高您的PyTorch技能并加快训练过程,请考虑学习PyTorch Lightning
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

*pprp*

如果有帮助可以打赏一杯咖啡

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值