pytorch模型训练使用(以FashionMNIST为例)

·


前言

根据之前学习到的pytorch一整个系列的流程,可以自己去写一个深度学习模型,并且进行一系列完整的测试。


本文采用fashion_mnist数据集进行训练,手写一个网络模型,在测试集上的准确率达到 。代码结构如下图所示:
在这里插入图片描述

一、数据集

数据集下载:

trans = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),  # 水平0.5概率翻转
        torchvision.transforms.RandomRotation(degrees=30)  # 30度旋转
    ])
mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=False
    )
mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=False
    )

使用的是pytorch内部的数据集,对于自定义的数据集需要进行相关的预处理操作,可以看我之前的关于dataset,dataloader的博客。

train_dataload = data.DataLoader(mnist_train, batch_size=128, shuffle=True)
test_dataload = data.DataLoader(mnist_test, batch_size=128, shuffle=True)

二、模型

模型总共6层,4层卷积层,2层全连接层。model.py文件如下

from torch import nn
import torch

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=1, padding=1),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 256),
            nn.Dropout(p=0.5),
            nn.ReLU(),
            nn.Linear(256, 10)
        )

    def forward(self, x):
        return self.model1(x)

三、模型训练

模型训练train.py文件:

import torch
from torch.utils import data
import torchvision
import matplotlib.pyplot as plt
from model import model


def train_model(train_dataloader, test_dataloader, train_size, test_size, epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = model()
    net = net.to(device)

    optimer = torch.optim.Adam(net.parameters(), lr=0.001)

    loss_fn = torch.nn.CrossEntropyLoss()
    loss_fn = loss_fn.to(device)

    train_loss = []
    test_loss = []
    test_accuracy = []
    step = 0
    for epoch in range(epochs):
        net.train()
        total_train_loss = 0
        print("-------------------第{}轮训练开始------------------".format(epoch + 1))
        for data in train_dataloader:
            imgs, targets = data
            imgs, targets = imgs.to(device), targets.to(device)
            output = net(imgs)
            optimer.zero_grad()
            loss = loss_fn(output, targets)
            total_train_loss += loss.item()
            loss.backward()
            optimer.step()
            step += 128
            if step % 1024 == 0:
                print("第{}次train,loss:{}".format(step, loss / len(targets)))

        train_loss.append(total_train_loss / train_size)
        net.eval()
        total_test_loss = 0
        total_test_accuracy = 0
        with torch.no_grad():
            for data in test_dataloader:
                imgs, targets = data
                imgs, targets = imgs.to(device), targets.to(device)
                output = net(imgs)
                loss = loss_fn(output, targets)
                total_test_loss += loss.item()
                total_test_accuracy += (output.argmax(1) == targets).sum()
        test_loss.append(total_test_loss / test_size)
        test_accuracy.append(total_test_accuracy / test_size)
        print("test集,loss:{},accuracy:{}".format(total_test_loss / test_size,
                                                 total_test_accuracy / test_size))
        if (epoch + 1) % epochs == 0:
            torch.save(net, "fashion_mnistmodel{}.pth".format(epoch))

    plt.xlabel("epoch")
    plt.ylabel("val")
    plt.plot(range(1, epochs + 1), train_loss,
             range(1, epochs + 1), test_loss,
             range(1, epochs + 1), test_accuracy)
    plt.legend(["train_loss", "test_loss", "test_accuracy"])
    plt.show()


if __name__ == '__main__':
    trans = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),  # 水平0.5概率翻转
        torchvision.transforms.RandomRotation(degrees=30)  # 30度旋转
    ])
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=False
    )
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=False
    )
    train_dataload = data.DataLoader(mnist_train, batch_size=128, shuffle=True)
    test_dataload = data.DataLoader(mnist_test, batch_size=128, shuffle=True)

    train_model(train_dataload, test_dataload, mnist_train.__len__(), mnist_test.__len__(), 35)

绘制得到的loss图像如下所示:
在这里插入图片描述

四、模型测试

对于预测后需要进行标签对应,代码如下:

def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

可以使用plt来绘制一下fashion_mnist的图像,代码如下:

def show_images(imgs, num_rows, num_cols, titles=None):
    """绘制图像列表"""
    for i, x in enumerate(imgs):
        # 绘制一个n*m个图片围成的画布
        plt.subplot(num_rows, num_cols, i + 1)
        plt.imshow(x.squeeze(0))
        plt.title(titles[i])
        plt.xticks([])
        plt.yticks([])
    plt.show()

x, y = next(iter(data.DataLoader(mnist_train, batch_size=18, shuffle=True)))
show_images(x.reshape(18, 28, 28), 3, 6, titles=get_fashion_mnist_labels(y))

在这里插入图片描述

总结

本文实现了fashion_mnist的数据集模型搭建和测试,因为只是记录一下pytorch搭建网络的一个大体过程,所以很多地方没有进行解释和注释,有不懂的欢迎大家在评论区或私信对我提问。

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Master___Yang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值