【PyTorch Lightning】简介

目录

一、Lightning 简约哲学

1.1 研究代码 (Research code)

1.2 工程代码 (Engineering code)

1.3 非必要代码 (Non-essential code)

二、典型的 AI 研究项目

三、生命周期

四、使用 Lightning 的好处


相关文章

【PyTorch Ligntning】快速上手简明指南 

【PyTorch Ligntning】如何将 PyTorch 组织为 Lightning

【PyTorch Lightning】1.0 正式发布:从 0 到 1 


PyTorch 已足够简单易用,但简单易用不等于方便快捷。特别是做大量实验时,很多东西都会变得复杂,代码也会变得庞大,此时就容易出错。

针对该问题,就有了 PyTorch Lightning。它可以重构你的 PyTorch 代码,抽出复杂重复部分,让你专注于核心的构建,让你的实验更快速更便捷地开展迭代。


一、Lightning 简约哲学

大部分的 DL/ML 代码都可以分为以下这三部分:

  • 研究代码 Research code
  • 工程代码 Engineering code
  • 非必要代码 Non-essential code

1.1 研究代码 (Research code)

这部分属于模型(神经网络)部分,一般处理模型的结构、训练等定制化部分。

在 Linghtning 中,这部分代码抽象为 LightningModule 类。

1.2 工程代码 (Engineering code)

这部分代码很重要的特点是:重复性强,如设置 early stopping、16位精度、GPUs 分布训练等。

在 Linghtning 中,这部分抽象为 Trainer 类。

1.3 非必要代码 (Non-essential code)

这部分代码有利于实验的进行,但和实验没有直接关系,甚至可以不用。如检查梯度、向 tensorboard 输出 log。

在 Linghtning 中,这部分抽象为 Callbacks 类。


二、典型的 AI 研究项目

在大多数研究项目中,研究代码 通常可以归纳到以下关键部分:

  • 模型
  • 训练/验证/测试 数据
  • 优化器
  • 训练/验证/测试 计算

上面已经提到,研究代码 在 Lightning 中,是抽象为 LightningModule 类;而该类与我们平时在 PyTorch 中使用的 torch.nn.Module 是一样的 (在原有代码中直接替换 Module 而不改其他代码也可以)。但不同的是,Lightning 围绕 torch.nn.Module 做了很多功能性的补充,把上面 4 个关键部分都囊括了进来。

如此设定的意义在于:我们的 研究代码 都是围绕 神经网络模型 来运行的,所以 Lightning 把这部分代码都集合在一个类里。所以接下来的介绍,都围绕 LightningModule 类来展开。


三、生命周期

为先呈现一个总体的概念,此处先介绍 LightningModule 中运行的生命流程。

以下所有函数都在 LightningModule 类中。

这部分是训练开始之后的执行 “一般(默认)顺序”

  • 首先是准备工作,包括初始化 LightningModule,准备数据 和 配置优化器。

这部分代码 只执行一次

1. `__init__()`(初始化 LightningModule )
2. `prepare_data()` (准备数据,包括下载数据、预处理等等)
3. `configure_optimizers()` (配置优化器)
  • 测试 “验证代码”。

提前来做的意义在于:不需要等待漫长的训练过程才发现验证代码有错。

这部分就是提前执行 “验证代码”,所以和下面的验证部分是一样的。

1. `val_dataloader()`
2. `validation_step()`
3. `validation_epoch_end()`
  • 开始加载dataloader,用来给训练加载数据

1. `train_dataloader()` 
2. `val_dataloader()` (如果你定义了)
  • 下面部分就是循环训练了,_step() 指按 batch 进行的部分;_epoch_step() 指所有 batch 执行完后 (一个 epoch) 要进行的部分。

# 循环训练与验证
1. `training_step()`
2. `validation_step()`
3. `validation_epoch_end()`
  • 最后训练完了,就要进行测试,但测试部分需手动调用 .test(),以避免误操作。

# 测试(需要手动调用)
1. `test_dataloader()` 
2. `test_step()`
3. `test_epoch_end()`

不难总结,在训练部分,主要包含三部分:_dataloader / _step / _epoch_end。Lightning 把训练的三部分抽象成三个函数,而用户只需要“填鸭式”地补充这三部分,就可以完成模型训练部分代码的编写。

为更清晰地展现这三部分的具体位置,以下用 PyTorch 实现方式 来展现其位置。

for epoch in epochs:
    for batch in train_dataloader:
        # train_step
        # ....
        # train_step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
    for batch in val_dataloader:
        # validation_step
        # ....
        # validation_step
    
    # *_step_end
    # ....
    # *_step_end

四、使用 Lightning 的好处

  • 只需专注于 研究代码

不需要写一大堆的 .cuda() 和 .to(device),Lightning 会自动处理。如果要新建一个 tensor,可以使用 type_as 来使得新tensor 处于相同的处理器上。

def training_step(self, batch, batch_idx):
    x, y = batch

    # 把z放在和x一样的处理器上
    z = sample_noise()
    z = z.type_as(x)

此处需要注意的是,不是所有的在 LightningModule 的 tensor 都会被自动处理,而是只有从 Dataloader 里获取的 tensor 才会被自动处理,所以对于 transductive learning 的训练,最好自己写 Dataloader 的处理函数。 

  • 工程代码参数化

平时写模型训练时,这部分代码会不断重复,但又不得不做,比如 early stopping、精度调整、显存内存间数据转移等。这部分代码虽然不难,但减少这部分代码会使得 研究代码 更加清晰,整体也更加简洁。

下面是简单的展示,表示使用 LightningModule 建立好模型后,如何进行训练。

model = LightningModuleClass()
trainer = pl.Trainer(gpus="0",  # 用来配置使用什么GPU
                     precision=32, # 用来配置使用什么精度,默认是32
                     max_epochs=200 # 迭代次数
                     )

trainer.fit(model)  # 开始训练
trainer.test()  # 训练完之后测试

参考文献

https://zhuanlan.zhihu.com/p/120331610

https://pytorch-lightning.readthedocs.io/en/latest/

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch Lightning是一种轻量级的高级PyTorch封装,它使得训练神经网络更加容易、更加模块化。它提供了许多常用的功能,例如自动分布式训练、自动检查点、自动日志记录等等。下面是一个PyTorch Lightning的学习指南: 1. 先学习PyTorch基础知识:在学习PyTorch Lightning之前,您需要先学习PyTorch的基础知识,例如如何构建神经网络、如何训练模型等等。 2. 安装PyTorch Lightning:在安装PyTorch Lightning之前,您需要先安装PyTorch。然后可以通过pip安装PyTorch Lightning。 3. 了解PyTorch Lightning的核心概念:PyTorch Lightning的核心概念是“LightningModule”、“Trainer”和“DataModule”。LightningModule是您定义神经网络的地方,Trainer是您定义训练过程的地方,DataModule是您定义数据集的地方。 4. 编写您的第一个PyTorch Lightning程序:您可以从一个简单的例子开始,例如MNIST手写数字识别。在这个例子中,您可以定义一个LightningModule来构建神经网络,定义一个DataModule来加载数据集,然后定义一个Trainer来训练模型。 5. 学习如何自动分布式训练:PyTorch Lightning可以自动进行分布式训练,这意味着您可以在多个GPU或多台计算机上训练模型。您只需要在Trainer中设置一些参数即可。 6. 学习如何自动检查点和日志记录:PyTorch Lightning可以自动保存检查点和记录日志,这使得您可以在训练过程中随时恢复模型并查看训练指标。 7. 学习如何使用PyTorch Lightning扩展您的研究:PyTorch Lightning提供了许多扩展功能,例如自动优化器、自动批量大小调整、自动对抗性训练等等。您可以使用这些功能来扩展您的研究。 总之,PyTorch Lightning是一个非常强大的工具,可以使训练神经网络更加容易和高效。如果您想提高您的PyTorch技能并加快训练过程,请考虑学习PyTorch Lightning
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值