MNIST手写数字识别

进入到研究生阶段了,从头学一下Pytorch,在这个小破站上记录一下自己的学习过程。
本文使用的是Pytorch来做手写数字的识别。

step0:先引入一些相关的包和库

import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import torchvision
from matplotlib import pyplot as plt


from utils import plot_image,plot_curve,one_hot

这里的utils是定义的一些辅助工具,包括loss下降的绘图函数和one_hot编码及图片显示的辅助函数。代码如下:
utils.py

# !/usr/bin/python3
# -*- coding:utf-8 -*-
# Author:WeiFeng Liu
# @Time: 2021/10/26 下午4:47

import torch
from matplotlib import pyplot as plt

###loss下降
def plot_curve(data):
    fig = plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    plt.ylabel('value')
    plt.show()



def plot_image(img,label,name):

    fig = plt.figure()
    for i in range(6):
        plt.subplot(2,3,i+1)
        plt.tight_layout()
        plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')
        plt.title("{}:{}".format(name,label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()

def one_hot(labels,depth=10):
    out = torch.zeros(labels.size(0),depth)
    idx = torch.LongTensor(labels).view(-1,1)
    out.scatter_(dim = 1, index = idx,value=1)
    return out

step1:加载数据
使用torch的DataLoader方法加载数据,MNIST数据集中的图片大小为28*28,比较小,batch_size可以设置大一点。

batch_size = 512
###step1  load dataset
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data',train=True,download=True,
                               transform = torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   #数据归一化
                                   torchvision.transforms.Normalize(
                                       (0.1307,),(0.3081,))
                               ])),
    batch_size = batch_size,shuffle = True
)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data/',train=False,download=True,
                               transform = torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,),(0.3081,))
                               ])),
    batch_size = batch_size , shuffle = False
)

transforms.Compose方法将数据转为Tensor和做数据归一化,训练集中设置shuffle=True是将训练数据打乱.

step2:定义网络结构
使用简单的三层线性模型来做简单的识别。

class Net(nn.Module):

    def __init__(self):
        super(Net,self).__init__()

        self.fc1 = nn.Linear(28*28,256)
        self.fc2 = nn.Linear(256,64)
        self.fc3 = nn.Linear(64,10)

    def forward(self, x):
        #x:[batch_size,1,28,28]
        x = F.relu(self.fc1(x))

        x = F.relu(self.fc2(x))

        x = self.fc3(x)

        return x

step3:train
训练3个epoch

train_loss = []
net =Net()
optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)

for epoch in range(3):
    for batch_idx,(x,y) in enumerate(train_loader):
        # x:[batch_size,1,28,28]
        #将x打平成二维的
        # y:batch_size
        x = x.view(x.size(0),28*28)
        out = net(x)
        y_onehot = one_hot(y)

        ##lose = mse(y,out)

        loss = F.mse_loss(out,y_onehot)

        optimizer.zero_grad() #梯度清零
        loss.backward() #计算梯度
        optimizer.step() #更新参数

        ##打印loss
        train_loss.append(loss.item())
        if batch_idx % 10 == 0:
            print(epoch,batch_idx,loss.item())
plot_curve(train_loss)

20211027 205211屏幕截图.png

step4:test
最后在验证集测试训练的准确率

total_correct = 0
for x,y in test_loader:
    x = x.view(x.size(0),28*28)
    out = net(x)
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct
total_num = len(test_loader.dataset)

acc = total_correct / total_num
print("test acc:",acc)

20211027 205405屏幕截图.png

  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值