Pytorch-lightning入门实例

Pytorch-lightning入门实例

本文将通过Colab平台及MINIST数据集指导你了解Pytorch-lightning的核心组成。

注意:任何的深度学习、机器学习的Pytorch工程都可以转变为lightning结构

从MNIST到自动编码器

安装Lightning

虽然说安装Lightning非常的容易,但还是建议大家在本地通过conda来安装Lightning

conda activate my_env
pip install pytorch-lightning

当你运行在在Google Colab上时,需要执行

!pip install pytorch-lightning

截屏2020-12-27 上午10.29.06

截屏2020-12-27 上午10.29.19

当出现上述内容是,代表Lightning已经安装完成。

或者你也可以使用conda命令来安装

conda install pytorch-lightning -c conda-forge

The research

The Model

Lightning由以下核心部分组成:

  • The model
  • The optimizers
  • The train/val/test steps

我们通过Model引入这一部分,下面我们将会设计一个三层的神经网络模型

import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

class LitMNIST(LightningModule):

  def __init__(self):
    super().__init__()

    # mnist images are (1, 28, 28) (channels, width, height)
    self.layer_1 = torch.nn.Linear(28 * 28, 128)
    self.layer_2 = torch.nn.Linear(128, 256)
    self.layer_3 = torch.nn.Linear(256, 10)

  def forward(self, x):
    batch_size, channels, width, height = x.size()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)
    x = F.relu(x)
    x = self.layer_3(x)

    x = F.log_softmax(x, dim=1)
    return x

可以观察到我们的LitMNIST类并不是像Pytorch中继承于torch.nn.Module类,而是Pytorch-lightning中的LightningModule类。该类与Pytorch中的torch.nn.Module类并没有太大的区别,只是增加了一些功能,我们可以像在Pytorch中使用torch.nn.Module类一样使用它。例如:

net = LitMNIST()
x = torch.randn(1, 1, 28, 28)
out = net(x)

截屏2020-12-27 上午10.39.35

下一步我们添加训练过程training_step,它继承于LightningModule,其中包含了所有训练过程中的逻辑内容

class LitMNIST(LightningModule):

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss
Data

LIghtning运行在dataloders上,以下是数据加载部分的代码:

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms

# transforms
# prepare transforms standard to MNIST
transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))])

# data
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_train = DataLoader(mnist_train, batch_size=64)

我比较推荐采用下面的方式加载数据:

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # download only
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

    def setup(self, stage):
        # transform
        transform=transforms.Compose([transforms.ToTensor()])
        mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

        # train/val split
        mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

        # assign to use in dataloaders
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = mnist_test

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

使用上面的DataModule可以更加方便的分项数据定义

# use an MNIST dataset
mnist_dm = MNISTDatamodule()
model = LitModel(num_classes=mnist_dm.num_classes)
trainer.fit(model, mnist_dm)

# or other datasets with the same model
imagenet_dm = ImagenetDatamodule()
model = LitModel(num_classes=imagenet_dm.num_classes)
trainer.fit(model, imagenet_dm)
Optimizer

在Pytorch中,我们通过下列方式来Optimizer代码:

from torch.optim import Adam
optimizer = Adam(LitMNIST().parameters(), lr=1e-3)

在Lightning中,方法类似,但是我们会将其包含在configure_optimizers()方法中

class LitMNIST(LightningModule):

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)
Training step

我们将其写在training_step()`方法中

class LitMNIST(LightningModule):

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss
Training

目前为止,我们已经完成了四个主要的代码部分

import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

class LitMNIST(LightningModule):

    def __init__(self):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

最后执行下列代码训练我们的数据:

from pytorch_lightning import Trainer
dm = MNISTDataModule()
model = LitMNIST()
trainer = Trainer(gpus=8)
trainer.fit(model, dm)

训练时间比较长

小结

我在博客中的每篇文章都是我一字一句敲出来的,转载的文章我也注明了出处,表示对原作者的尊重。同时也希望大家都能尊重我的付出。

最后,也希望大家关注我的个人博客HD Blog

谢谢~

  • 12
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Johnny__Wang__

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值