基于PyTorch的MNIST手写数字识别

# !/usr/bin/env Python3
# -*- coding: utf-8 -*-
# @version: v1.0
# @Author   : Meng Li
# @contact: 925762221@qq.com
# @FILE     : torch_mnist.py
# @Time     : 2022/5/31 9:29
# @Software : PyCharm
# @site:
# @Description : 自己动手实现mnist数据集的10分类任务
# 同等条件下，batch_size 越小，模型收敛越快（每个 epoch 的参数更新次数更多），但损失更容易震荡；learning_rate 越小，模型收敛速度越慢。

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torchsummary
import torch.optim as optim
from torch.utils.data import Dataset
import matplotlib.pylab as plt


class Net(nn.Module):
    """Two-layer MLP classifier for 28x28 MNIST digit images.

    ``forward`` computes the loss and the batch accuracy together with the
    predictions, so it must always be called with the labels.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 100)  # flattened image -> hidden layer
        self.fc2 = nn.Linear(100, 10)  # hidden layer -> 10 class logits
        # Attribute name (sic) kept as-is: models pickled with
        # torch.save(net, ...) store this submodule under this exact name.
        self.crition = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        """Classify a batch of images and score the predictions against labels.

        Args:
            x: 4-D image batch (batch, channels, H, W); fc1 requires
                channels * H * W == 28 * 28, i.e. single-channel 28x28 images.
            y: integer class labels, one per image.

        Returns:
            (loss, acc, index): mean cross-entropy loss, batch accuracy as a
            0-dim float tensor on the CPU, and the predicted class per image.
        """
        _, _, h, w = x.size()
        # reshape (not view) so non-contiguous inputs are accepted as well.
        hidden = F.relu(self.fc1(x.reshape(-1, h * w)))
        output = self.fc2(hidden)
        loss = self.crition(output, y)
        index = output.argmax(dim=1)
        acc = (index == y).float().mean().cpu()
        return loss, acc, index


def train(epochs=30, batch_size=64, learning_rate=0.001, save_path='limeng.pth',
          data_root="./"):
    """Train ``Net`` on the MNIST training split with plain SGD.

    After every epoch the whole model is pickled to ``save_path`` whenever
    that epoch's mean training accuracy beats the best seen so far.

    Args:
        epochs: number of passes over the training set.
        batch_size: mini-batch size.
        learning_rate: SGD learning rate.
        save_path: file the best model is written to with ``torch.save``.
        data_root: directory holding the MNIST files; ``download=False``, so
            they must already exist there.
    """
    net = Net()
    show_sum_flg = False
    if show_sum_flg:
        # NOTE(review): Net.forward needs a 4-D input *and* labels, so this
        # summary call probably fails as written -- verify before enabling.
        torchsummary.summary(net, (28, 28))
    train_data = torchvision.datasets.MNIST(root=data_root, train=True,
                                            transform=torchvision.transforms.ToTensor(),
                                            download=False)
    optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    train_iter = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True)

    max_acc = 0
    for i in range(epochs):
        net.train()
        correct = 0.0
        seen = 0
        for image, label in train_iter:
            loss, acc, _ = net(image, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # acc is a per-batch ratio; weight it by the batch size so a
            # smaller final batch does not skew the epoch average.
            correct += acc.item() * label.size(0)
            seen += label.size(0)
        epoch_acc = correct / seen if seen else 0.0
        print("epoch {0}  acc {1}".format(i, epoch_acc))
        # Checkpoint on the whole epoch's accuracy, not on the last
        # mini-batch alone (a noisy, possibly partial sample).
        if epoch_acc > max_acc:
            max_acc = epoch_acc
            torch.save(net, save_path)


def test(model_path='limeng.pth', data_root="./"):
    """Load a saved model and plot its predictions for 10 random MNIST test images.

    Args:
        model_path: file written by ``train`` via ``torch.save(net, ...)``.
        data_root: directory holding the MNIST files; ``download=False``, so
            they must already exist there.
    """
    # train() pickles the whole module, not a state_dict. torch >= 2.6
    # defaults to weights_only=True, which refuses to unpickle an nn.Module,
    # so opt out explicitly. Only load files you trust: unpickling can run
    # arbitrary code.
    try:
        net = torch.load(model_path, weights_only=False)
    except TypeError:  # torch < 1.13 has no weights_only argument
        net = torch.load(model_path)
    net.eval()
    # Use the held-out test split, not the data the model was trained on.
    test_data = torchvision.datasets.MNIST(root=data_root, train=False,
                                           transform=torchvision.transforms.ToTensor(),
                                           download=False)
    batch_size = 10
    test_iter = torch.utils.data.DataLoader(test_data, batch_size, shuffle=True)

    image, label = next(iter(test_iter))  # one random batch is enough to preview
    with torch.no_grad():  # inference only; no autograd graph needed
        _, _, predict = net(image, label)
    for i in range(len(label)):
        plt.subplot(2, 5, i + 1)
        plt.imshow(image[i, 0, :, :], cmap="gray")
        # .item() so the title reads "7" rather than "tensor(7)".
        plt.title("{0}".format(predict[i].item()))
    plt.show()


if __name__ == '__main__':
    # Only the evaluation/visualisation step runs by default; call train()
    # beforehand to produce the limeng.pth checkpoint that test() loads.
    test()

先上代码。工作期间我接触了 TensorFlow 和 PyTorch 两种框架，总的来说，PyTorch 的编码风格更接近 Python 原生语法，所以更容易上手。MNIST 手写数字识别算是深度学习中的“hello world”，还是有必要亲手把从数据输入、模型训练、模型保存到模型测试的全流程写一遍。

测试模型时，运行效果图大概是这样的（每张子图的标题是模型对该图片的预测结果）：

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值