CNN网络实现手写数字(MNIST)识别 代码分析

本文详细介绍了使用Pytorch框架构建的CNN网络如何对手写数字(MNIST)进行识别,包括数据准备、网络配置、模型训练和评估。网络结构包含两个卷积层和三个全连接层,通过交叉熵损失函数进行优化,最终在测试集上取得了高准确率。
摘要由CSDN通过智能技术生成

CNN网络实现手写数字(MNIST)识别 代码分析(自学用)

Github代码源文件
本文是学习了使用Pytorch框架的CNN网络实现手写数字(MNIST)识别

#导入需要的包
import numpy as np   //第三方库,用于进行科学计算
import torch 
from torch import nn
from PIL import Image  // Python Image Library,python第三方图像处理库
import matplotlib.pyplot as plt //python的绘图库 pyplot:matplotlib的绘图框架
import os //提供了丰富的方法来处理文件和目录
from torchvision import datasets, transforms,utils //提供很多数据集的下载,包括COCO,ImageNet,CIFCAR等

1. 准备数据

(1)数据集介绍
MNIST数据集包含60000个训练集和10000测试数据集。分为图片和标签,图片是28*28的像素矩阵,标签为0~9共10个数字。

transform = transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize(mean=[0.5],std=[0.5])])
                              
//Compos把多种数据处理的方法集合在一起
//使用transforms进行Tensor格式转换,将灰度范围从0-255变换到0-1之间
//批标准化(Batch Normalization),其作用就是先将输入归一化到(0,1),再使用公式"(x-mean)/std",将每个元素分布到(-1,1)
train_data = datasets.MNIST(root = "./data/"//root为数据集存放的路
                           transform=transform, //transform指定数据集导入的时候需要进行的变换
                           train = True,    //train设置为true表明导入的是训练集合,否则是测试集合
                           download = True) //如果为true,请从互联网下载数据集,然后将其放在根目录中。 如果数据集已经下载,则不是再次下载。

test_data = datasets.MNIST(root="./data/",
                          transform = transform,
                          train = False)
//train_data 的个数:60000个训练样
//test_data 的个数:10000个训练样本 
//一个样本的格式为[data,label],第一个存放数据,第二个存放标签
                     
train_loader = torch.utils.data.DataLoader(train_data,batch_size=64,
                                         shuffle=True,num_workers=2)
test_loader = torch.utils.data.DataLoader(test_data,batch_size=64,
                                         shuffle=True,num_workers=2)
//设置batch_size表示每次训练的样本数量 ,加载器中的基本单位是一个batch的数据 ,这里是64

//所以train_loader 的长度是60000/64 = 938 个batch,test_loader 的长度是10000/64= 157 个batch
                          
//shuffle 将序列的所有元素随机排序。
//num_workers 表示用多少个子进程加载数据

从二维数组生成一张图片

oneimg,label = train_data[0]
oneimg = oneimg.numpy().transpose(1,2,0) //numpy.transpose默认第一个方括号“[]”为 0轴 ,第二个方括号为 1轴...所以有着交换轴改变矩阵序列的作,(x=0,y=1,z=2),新的x是原来的y轴大小,新的y是原来的z轴大小,新的z是原来的x大小
std = [0.5]  //标准差
mean = [0.5] //平均值
oneimg = oneimg * std + mean
oneimg.resize(28,28)
plt.imshow(oneimg)
plt.show()

在这里插入图片描述
从三维生成一张黑白图片

oneimg,label = train_data[0]
grid = utils.make_grid(oneimg) //make_grid的作用是将若干幅图像拼成一幅图像。在需要展示一批数据时很有用。
grid = grid.numpy().transpose(1,2,0) 
std = [0.5]
mean = [0.5]
grid = grid * std + mean
plt.imshow(grid)
plt.show(

在这里插入图片描述
输出一个batch的图片和标签

images, lables = next(iter(train_loader))
//next()函数:不断返回迭代器下一个值
//iter()函数:把list,dict,str等可迭代的对象Iterable(可以用for循环的对象)转换为迭代器Iterator可以使用
img = utils.make_grid(images
img = img.numpy().transpose(1,2,0) 
std = [0.5]
mean = [0.5]
img = img * std + mean
for i in range(64):
   print(lables[i], end=" ")
   i += 1
   if i%8 is 0:
       print(end='\n')
plt.
  • 8
    点赞
  • 86
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值