(5)深度学习学习笔记-多层感知机-pytorch lightning版


前言

pytorch lighting是导师推荐给我学习的一个轻量级的PyTorch库,代码干净简洁,使用pl更容易理解ML代码,对于初学者的我还是相对友好的。
pytorch lightning官网网址
https://lightning.ai/docs/pytorch/stable/levels/core_skills.html


多层感知机pl代码

1.引入库

代码如下:

import os
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from torchvision import transforms
from torch.utils import data

# 处理anaconda和torch重复文件
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

2.读入数据

代码如下:(可以直接把download改为true下载)

def load_data_fashion_mnist(batch_size, resize=None):  # 图片28*28*1
    """在本地读入Fashion-MNIST数据集"""
    trans = [transforms.ToTensor()]  # 把图片转换为pytorch tensor
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)

    mnist_train = torchvision.datasets.FashionMNIST(
        root="D:/python_project/fashion-mnist-master/fashion-mnist-master/data/fashion",
        train=True,
        transform=trans,
        download=False
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root="D:/python_project/fashion-mnist-master/fashion-mnist-master/data/fashion",
        train=False,
        transform=trans,
        download=False
    )
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=0),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=0))

3.pl二层感知机

# 二层感知机
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28*28, 256), nn.ReLU(), nn.Linear(256, 10))

    def forward(self, x):
        return self.l1(x)

class Perceptron(pl.LightningModule):
# pl模块和nn模块交互

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self.encoder(x)
        loss = F.cross_entropy(y_hat, y)

        print("train_loss=", loss)
        return loss

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self.encoder(x)
        test_loss = F.cross_entropy(y_hat, y)
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-1)
        return optimizer


batch_size = 256
# 训练/测试集
train_loader, test_loader = load_data_fashion_mnist(batch_size)

# 模型
model = Perceptron(Encoder())
# 训练模型
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=train_loader)
# 测试
trainer.test(dataloaders=test_loader)

总结

更多pl的方法,可以到pl官网查看

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值