pytorch lightning 手写数字分类实例 (三)

今天通过手写数字来学习如何利用pytorch-lightning进行分类
代码同第二部分的差不多,新增了断点训练和测试部分。
项目使用jupyter notebook演示
此部分代码很简单,小白也能上手,赶快来试一试吧~~~

该系列还有
pytorch-lightning入门(一)—— 初了解
如何从Pytorch 到 Pytorch Lightning (二) | 简要介绍

import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torch.nn import functional as F
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import os
from pytorch_lightning import seed_everything
import numpy as np
import matplotlib.pyplot as plt

SET SEED

# 首先设置随机数种子
seed_everything(seed=42)

在这里插入图片描述

# 定义模型
class LightningMNISTClassifier(pl.LightningModule):

    def __init__(self):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, width, height)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)

        # layer 1 (b, 1*28*28) -> (b, 128)
        x = self.layer_1(x)
        x = torch.relu(x)

        # layer 2 (b, 128) -> (b, 256)
        x = self.layer_2(x)
        x = torch.relu(x)

        # layer 3 (b, 256) -> (b, 10)
        x = self.layer_3(x)

        # probability distribution over labels
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

dataloader

# data
# transforms for images
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

# prepare transforms standard to MNIST
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)

train_dataloader = DataLoader(mnist_train, batch_size=64)
val_loader = DataLoader(mnist_test, batch_size=64)

没有下载的会自动进行下载,如果速度慢,就手动下载保存。有VPN 可以直接运行代码

training

接下来开始训练,提供两种训练方法

  • 从头训练
  • 从断点开始训练
# train
model = LightningMNISTClassifier()

# resume training
RESUME = False
if RESUME:
    resume_checkpoint_dir = './lightning_logs/version_1/checkpoints/'
    checkpoint_path = os.listdir(resume_checkpoint_dir)[0]
    resume_checkpoint_path = resume_checkpoint_dir + checkpoint_path

    trainer = pl.Trainer(gpus='1',
                         max_epochs=10,
                         resume_from_checkpoint=resume_checkpoint_path)

    trainer.fit(model, train_dataloader, val_loader)
else:
    trainer = pl.Trainer(gpus='1', max_epochs=20)

    trainer.fit(model, train_dataloader, val_loader)

输出包括:
在这里插入图片描述
在这里插入图片描述

训练结果默认保存在文件夹: ./lightning_logs.
会根据你运行的次数自动命名这是版本x。

在这里插入图片描述
运行期间可以打开tensorboard 查看运行情况

在这里插入图片描述

testing

# test
checkpoint_dir = 'lightning_logs/version_2/checkpoints/'
checkpoint_path = checkpoint_dir + os.listdir(checkpoint_dir)[0]
model = LightningMNISTClassifier()
model.load_state_dict(torch.load(checkpoint_path)['state_dict'])

inputs, labels = next(iter(val_loader))
# inference
outputs = model(inputs)

这里,我只测试一个batch的数据。

测试结果显示

def imshow(inp):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    plt.show()
import torchvision
# 求outputs最大索引
_, preds = torch.max(outputs, dim=1)
# print images and ground truth
imshow(torchvision.utils.make_grid(inputs))
print('GroundTruth:', labels)
print ('Prediction:', preds)

在这里插入图片描述
可见,这部分的准确率为100%

tips:使用lightning模式,用jupyter notebook显示进度条不是很良好,应该要进行相应的设置,我这里不太想去研究了。在终端运行,显示效果更好。

此部分代码很简单,赶快来试一试吧

文章持续更新,可以关注微信公众号【医学图像人工智能实战营】,一个关注于医学图像处理领域前沿科技的公众号。坚持已实践为主,手把手带你做项目,打比赛,写论文。凡原创文章皆提供理论讲解,实验代码,实验数据。只有实践才能成长的更快,关注我们,一起学习进步~

我是Tina, 我们下篇博客见~

最后,求点赞,评论,收藏。或者一键三连
在这里插入图片描述

  • 8
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Tina姐

我就看看有没有会打赏我

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值