pytorch lightning 使用记录

pytorch lightning

由于我们在使用pytorch时遇到的各种问题,出现了Pytorch-Lighting这一框架。
在这里插入图片描述

Pytorch-Lighting 的一大特点是把模型和系统分开来看。系统定义了一组模型如何相互交互,如GAN(生成器网络与判别器网络)、Seq2Seq(Encoder与Decoder网络)和Bert。有时候问题只涉及一个模型,例如UNet、ResNet等,那么这个系统则可以是一个通用的系统,用于描述模型如何使用,并可以被复用到很多其他项目。
Pytorch-Lighting 框架下,每个网络包含了如何训练、如何测试、优化器定义等内容。

在这里插入图片描述

Lightning社区中的Face Mask Detector: https://towardsdatascience.com/how-i-built-a-face-mask-detector-for-covid-19-using-pytorch-lightning-67eb3752fd61

在这里插入图片描述

以手写体识别为例程的PyTorch Lightning教程:https://colab.research.google.com/drive/1Mowb4NzWlRCxzAFjOIJqUmmk_wAT-XP3

1.安装

pip install pytorch-lightning

2.网络设计

import torch
from torch import nn
import pytorch_lightning as pl

class LightningMNISTClassifier(pl.LightningModule):

  def __init__(self):
    super(LightningMNISTClassifier, self).__init__()

    # mnist images are (1, 28, 28) (channels, width, height) 
    self.layer_1 = torch.nn.Linear(28 * 28, 128)
    self.layer_2 = torch.nn.Linear(128, 256)
    self.layer_3 = torch.nn.Linear(256, 10)

  def forward(self, x):
    batch_size, channels, width, height = x.siz()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)

    # layer 1
    x = self.layer_1(x)
    x = torch.relu(x)

    # layer 2
    x = self.layer_2(x)
    x = torch.relu(x)

    # layer 3
    x = self.layer_3(x)

    # probability distribution over labels
    x = torch.log_softmax(x, dim=1)

    return x

可以发现PyTorch和LP是几乎一样的。

# restore with PyTorch
pytorch_model = MNISTClassifier()
pytorch_model.load_state_dict(torch.load(PATH))
model.eval()


lightning_model = LightningMNISTClassifier.load_from_checkpoint(PATH)
lightning_model.eval()

3. 数据

让我们生成MNIST数据集的三个数据集——训练、验证和测试。
在pytorch中,数据集被添加到Dataloader中,Dataloader处理数据集的加载、shuffling和batching 。

  1. 图像转换。
  2. 生成训练、验证和测试数据集。
  3. 将每个数据集载入到DataLoader中。

如下

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms


# ----------------
# TRANSFORMS
# ----------------
# prepare transforms standard to MNIST
transform=transforms.Compose([transforms.ToTensor(), 
                              transforms.Normalize((0.1307,), (0.3081,))])

# ----------------
# TRAINING, VAL DATA
# ----------------
mnist_train = MNIST(os.getcwd(), train=True, download=True)

# train (55,000 images), val split (5,000 images)
mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

# ----------------
# TEST DATA
# ----------------
mnist_test = MNIST(os.getcwd(), train=False, download=True)

# ----------------
# DATALOADERS
# ----------------
# The dataloaders handle shuffling, batching, etc...
mnist_train = DataLoader(mnist_train, batch_size=64)
mnist_val = DataLoader(mnist_val, batch_size=64)
mnist_test = DataLoader(mnist_test, batch_size=64)

在PyTorch中,这种数据加载可以在训练程序的任何地方进行,而在PyTorch Lightning中,dataloader可以直接使用,也可以在LightningDataModule下将三种方法组合起来使用

train_dataloader()
val_dataloader()
test_dataloader()

还有第四个方法是用于数据准备/下载的。

prepare_data()

Lightning采用这种方法,使每个用Lightning实现的模型都遵循相同的结构。这使得代码具有极高的可读性和组织性。

也就是说,当你遇到一个使用Lightning的项目,从代码中能清楚地知道数据处理/下载发生在哪里。
LP如下:

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms

class MNISTDataModule(pl.LightningDataModule):

  def setup(self, stage):
    # transforms for images
    transform=transforms.Compose([transforms.ToTensor(), 
                                  transforms.Normalize((0.1307,), (0.3081,))])
      
    # prepare transforms standard to MNIST
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)
    
    self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

  def train_dataloader(self):
    return DataLoader(self.mnist_train, batch_size=64)

  def val_dataloader(self):
    return DataLoader(self.mnist_val, batch_size=64)

  def test_dataloader(self):
    return DataLoader(
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值