由于我们在使用pytorch时遇到的各种问题,出现了Pytorch-Lighting这一框架。
Pytorch-Lighting 的一大特点是把模型和系统分开来看。系统定义了一组模型如何相互交互,如GAN(生成器网络与判别器网络)、Seq2Seq(Encoder与Decoder网络)和Bert。有时候问题只涉及一个模型,例如UNet、ResNet等,那么这个系统则可以是一个通用的系统,用于描述模型如何使用,并可以被复用到很多其他项目。
Pytorch-Lighting 框架下,每个网络包含了如何训练、如何测试、优化器定义等内容。
Lightning社区中的Face Mask Detector: https://towardsdatascience.com/how-i-built-a-face-mask-detector-for-covid-19-using-pytorch-lightning-67eb3752fd61
以手写体识别为例程的PyTorch Lightning教程:https://colab.research.google.com/drive/1Mowb4NzWlRCxzAFjOIJqUmmk_wAT-XP3
1.安装
pip install pytorch-lightning
2.网络设计
import torch
from torch import nn
import pytorch_lightning as pl
class LightningMNISTClassifier(pl.LightningModule):
def __init__(self):
super(LightningMNISTClassifier, self).__init__()
# mnist images are (1, 28, 28) (channels, width, height)
self.layer_1 = torch.nn.Linear(28 * 28, 128)
self.layer_2 = torch.nn.Linear(128, 256)
self.layer_3 = torch.nn.Linear(256, 10)
def forward(self, x):
batch_size, channels, width, height = x.siz()
# (b, 1, 28, 28) -> (b, 1*28*28)
x = x.view(batch_size, -1)
# layer 1
x = self.layer_1(x)
x = torch.relu(x)
# layer 2
x = self.layer_2(x)
x = torch.relu(x)
# layer 3
x = self.layer_3(x)
# probability distribution over labels
x = torch.log_softmax(x, dim=1)
return x
可以发现PyTorch和LP是几乎一样的。
# restore with PyTorch
pytorch_model = MNISTClassifier()
pytorch_model.load_state_dict(torch.load(PATH))
model.eval()
lightning_model = LightningMNISTClassifier.load_from_checkpoint(PATH)
lightning_model.eval()
3. 数据
让我们生成MNIST数据集的三个数据集——训练、验证和测试。
在pytorch中,数据集被添加到Dataloader中,Dataloader处理数据集的加载、shuffling和batching 。
- 图像转换。
- 生成训练、验证和测试数据集。
- 将每个数据集载入到DataLoader中。
如下
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms
# ----------------
# TRANSFORMS
# ----------------
# prepare transforms standard to MNIST
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
# ----------------
# TRAINING, VAL DATA
# ----------------
mnist_train = MNIST(os.getcwd(), train=True, download=True)
# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])
# ----------------
# TEST DATA
# ----------------
mnist_test = MNIST(os.getcwd(), train=False, download=True)
# ----------------
# DATALOADERS
# ----------------
# The dataloaders handle shuffling, batching, etc...
mnist_train = DataLoader(mnist_train, batch_size=64)
mnist_val = DataLoader(mnist_val, batch_size=64)
mnist_test = DataLoader(mnist_test, batch_size=64)
在PyTorch中,这种数据加载可以在训练程序的任何地方进行,而在PyTorch Lightning中,dataloader可以直接使用,也可以在LightningDataModule下将三种方法组合起来使用
train_dataloader()
val_dataloader()
test_dataloader()
还有第四个方法是用于数据准备/下载的。
prepare_data()
Lightning采用这种方法,使每个用Lightning实现的模型都遵循相同的结构。这使得代码具有极高的可读性和组织性。
也就是说,当你遇到一个使用Lightning的项目,从代码中能清楚地知道数据处理/下载发生在哪里。
LP如下:
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms
class MNISTDataModule(pl.LightningDataModule):
def setup(self, stage):
# transforms for images
transform=transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
# prepare transforms standard to MNIST
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=64)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=64)
def test_dataloader(self):
return DataLoader(