利用MINIST数据集识别手写数字

一、MINIST数据集

MNIST 数据集可在 http://yann.lecun.com/exdb/mnist/ 获取, 它包含了四个部分:

Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解压后 47 MB, 包含 60,000 个样本)

Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解压后 60 KB, 包含 60,000 个标签)

Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解压后 7.8 MB, 包含 10,000 个样本)

Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解压后 10 KB, 包含 10,000 个标签)

MNIST 数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST). 训练集 (training set) 由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员. 测试集(test set) 也是同样比例的手写数字数据.  在 MNIST 数据集中的每张图片由 28 x 28 个像素点构成, 每个像素点用一个灰度值表示

二、利用SoftMax进行分类

 

将Softmax输出的值,进行One-hot编码。One-hot编码后的预测值和真实值两分布的差异即为交叉熵损失函数。

 

三、网络架构图

 

输入的数据为(N,1,28,28) ,输出的数据为(N,10),每张图片,输出10个分类对应的数值。

四、代码实现

数据预处理,代码几处需解释的地方:

1. toTensor()能够把灰度范围从0-255变换到0-1之间,而后面的transform.Normalize()则把0-1变换到(-1,1). 期望为0.1307,方差为

0.3081的正太分布。

2. optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)

第三个参数momentum是冲量,“冲量”这个概念源自于物理中的力学,表示力对时间的积累效应。

在普通的梯度下降法x+=v中,每次x的更新量v为v=−dx∗lr,其中dx为目标函数func(x)对x的一阶导数。

当使用冲量时,则把每次x的更新量v考虑为本次的梯度下降量−dx∗lr与上次x的更新量v乘上一个介于[0,1][0,1]的因子momentum的和,即

v ′ = − d x ∗ l r + v ∗ m o m e m t u m v^{'}=−dx∗lr+v∗momemtum

当本次梯度下降- dx * lr的方向与上次更新量v的方向相同时,上次的更新量能够对本次的搜索起到一个正向加速的作用。
当本次梯度下降- dx * lr的方向与上次更新量v的方向相反时,上次的更新量能够对本次的搜索起到一个减速的作用。

3. enumerate()是Python的内置函数,

功能:将一个可遍历的数据对象(如列表、元组、字典和字符串)组合成一个索引序列,同时列出数据下标和数据(索引 值),一般配合for循环使用。

4. _,predicted=torch.max(outputs.data,dim=1)

dim=1,一行为1维度,一列为:0 。输出一行的最大值及其下标。

import torch

from torchvision import transforms

from torchvision import datasets

from torch.utils.data import  DataLoader

import torch.nn.functional as F

import torch.optim as optim



batch_size=64

#利用transforms 转换为tensor 类型的图片,

transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))

                               ])

train_dataset=datasets.MNIST(root='../data/mnist',train=True,download=True,transform=transform)

train_loader= DataLoader(train_dataset,shuffle=True,batch_size=batch_size)





test_dataset=datasets.MNIST(root='../data/mnist',train=False,download=True,transform=transform)

test_loader=DataLoader(train_dataset,shuffle=False,batch_size=batch_size)





class Net(torch.nn.Module):

    def __init__(self):

        super(Net,self).__init__()

        self.l1=torch.nn.Linear(784,512)

        self.l2=torch.nn.Linear(512,256)

        self.l3=torch.nn.linear(256,128)

        self.l4=torch.nn.Linear(128,64)

        self.l5=torch.nn.Linear(64,10)



    def forward(self,x):

        x=x.view(-1,784)

        x=F.relu(self.l1(x))

        x=F.relu(self.l2(x))

        x=F.relu(self.l3(x))

        x=F.relu(self.l4(x))

        return self.l5(x)



model=Net()



crierion=torch.nn.CrossEntropyLoss()

optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)



def train(epoch):

    running_loss=0.0

    for batch_idx,data in enumerate(train_loader,0):

        inputs,target=data

        optimizer.zero_grad()



        #forward +backward+update

        outputs=model(inputs)

        loss=crierion(outputs,target)

        loss.backward()

        optimizer.step()



        running_loss+=loss.item()

        if batch_idx%300==299:

            print('[%d,%5d] loss:%.3f' %(epoch+1,batch_idx+1,running_loss/300))

            running_loss==0.0



def test():

    correct=0

    total=0

    with torch.no_grad():

        for data in test_loader:

            images,labels=data

            outputs=model(images)

            _,predicted=torch.max(outputs.data,dim=1)

            total+=labels.size(0)

            correct+=(predicted==labels).sum().item()

    print('Accuacy on test set%d %%'%(100*correct/total))





if __name__=='__main__':

    for epoch in range(10):

        train(epoch)

        test()

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值