简介
Pytorch Lightning是基于Pytorch的一个框架,能够很好地组织用Pytorch实现的机器学习代码。它的强大之处在于将科研想法的实现和繁琐的工程代码解耦,让研究者只用关注于科研想法的实现,而工程代码则由Lightning模块帮助你完成,代码变得简洁而优雅。极度适合研究人员。
和学习一门语言一样,要学习一个框架,它背后的设计哲学很重要。
Pytorch Lightning的设计哲学可以概括为以下四点:
- 保证代码结构拥有最大的灵活性;
- 将不必要的样板代码抽象出来,但是同时保持其在必要时可以访问到;
- 系统应该是独立的;
- 深度学习的代码可以解耦成四部分,各司其职。首先是数据加载部分,可以使用pytorch提供的dataloader或者使用pytorch lightning的LightningDataModule; 其次是模型部分,这里是research idea主要体现的地方,也就是LightningModule;然后是训练部分,主要工程代码,由pytorch lightning的Trainer完成,不需要自己实现;最后相对次要的research相关代码,比如说logging等,在CallBacks中实现
如何安装
通过PyPI安装
pip install pytorch-lightning
通过conda安装
conda install pytorch-lightning -c conda-forge
Lighting训练的人性化设计
文章末有示例代码,在运行后,可以看到以下界面。
有几点我个人觉得还是很惊艳的,可以帮助提高日常工作效率,
- 本机GPU以及TPU检测,lightning是支持GPU以及TPU训练的,并且可以很方便的指定具体用哪块GPU来训练
- 参数量显示,在网络设计好之后我们可能需要花费时间去计算网络的参数量,lightning可以直接在代码运行后显示出来
- 进度条,以及每个epoch,iteration的耗时显示,可以清晰的帮助我们看到时间消耗。
- 合理的代码改进建议,比如说在数据加载中,lightning检测到我的机器有4核,但是在数据中我们只使用了一个核,可能会是训练速度的瓶颈,lightning建议通过设置num_workers的值,来提高训练速度。可以帮助我们快速的写出高质量的代码。
- 完善的模型以及日志机制,在代码运行完成后,lightning会自动帮我们保存相应的checkpoints,对应的日志文件以及hyperparameter
Lightning效率之美
lightning能够将一些繁琐的超参设计流程自动化,帮助我们提高工作效率。
1. Learning Rate Finder
在深度学习训练中,合理的学习率设置可以帮助网络取得更好的表现以及更快的收敛,lighting提供auto_lr_find的设置,在初始训练的时候帮我们找到一个比较合理的学习率,而不是单纯靠猜。
2. Auto Scale Batch Size
在机器学习任务中,一般来说,大的batch size可以提供更稳定的梯度,但是人为调节batch size可能会比较繁琐。Pytorch Lightning可以自动根据机器内存的大小缩放batch size,以达到最适合当前任务的一个值。
3. GPU 16-precision Training
更多自动化操作可以参考这个网址,PyTorch Lightning Documentation
参考资料
https://github.com/PyTorchLightning/pytorch-lightning
https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html
样例代码
import os
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
class LitAutoEncoder(pl.LightningModule):
def __init__(self) -> None:
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(28 * 28, 128),
nn.ReLU(),
nn.Linear(128, 3))
self.decoder = nn.Sequential(
nn.Linear(3, 128),
nn.ReLU(),
nn.Linear(128, 28 * 28))
def forward(self, x):
embedding = self.encoder(x)
return embedding
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
self.log('train_loss', loss)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [55000, 5000])
autoencoder = LitAutoEncoder()
trainer = pl.Trainer()
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))