Pytorch-Lighting(简称pl),它其实就是一个轻量级的PyTorch库
它把研究代码与工程代码相分离,还将PyTorch代码结构化,更加直观的展现数据操作过程。
这样,更加易于理解,不易出错,本来很冗长的代码一下子就变得轻便了,对AI研究者十分的友好。
PyTorch Lightning 就是pytorch的keras
之前在使用Pytorch的时候觉得在完成dataloader和model之后,还要写一堆train&test code有些过于繁琐。
在短暂地了解了Pytorch Lightning之后,发现这个框架可以在一定程度上解决这个问题,这也是作者开发Pytorch lightning框架的初衷——“花更多的时间在研究上,花更少的时间在工程上”。
conda install pytorch-lightning -c conda-forge
pytorch lightning通过提供LightningModule和LightningDataModule,使得在用pytorch编写网络模型时,加载数据、分割数据集、训练、验证、测试、计算指标的代码全部都能很好的组织起来,显得主程序调用时,代码简洁可读性大幅度提升
pytorch 和 pl 本质上代码是完全相同的。只不过pytorch需要自己造轮子(如model, dataloader, loss, train,test,checkpoint, save model等等都需要自己写),而pl 把这些模块都结构化了(类似keras)。
从下面的图片来看两者的区别
import os import torch from torch import nn import torch.nn.functional as F from torchvision import transforms from torchvision.datasets import MNIST from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl class LitAutoEncoder(pl.LightningModule): def __init__(self): super().__init__() self.encoder = nn.Sequential( nn.Linear(28*28, 64), nn.ReLU(), nn.Linear(64, 3) ) self.decoder = nn.Sequential( nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28*28) ) def forward(self, x): # in lightning, forward defines the prediction(预测)/inference(推理) actions embedding = self.encoder(x) return embedding def training_step(self, batch, batch_idx): # training_step defined the train loop. # It is independent of forward x, y = batch x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) loss = F.mse_loss(x_hat, x) # Logging to TensorBoard by default self.log('train_loss', loss) #self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) return loss def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()) train_loader = DataLoader(dataset, batch_size=10) # init model autoencoder = LitAutoEncoder() # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more) # trainer = pl.Trainer(gpus=8) (if you have GPUs) trainer = pl.Trainer( gpus=1, auto_scale_batch_size=True ) trainer.fit(autoencoder, train_loader)
PyTorch Lightning (pl)(简约版pytorch)(一) —— 简介
于 2021-10-18 16:23:41 首次发布