mnist torch加载fashion_梦幻:Pytorch使用记录3——MNIST

一、网络构建

整个网络的结构表述如下:

对于网络的前向传播,首先是进行一次卷积,由1维转化为20维,进行ReLU激活;之后最大池化,卷积,由20维转化为50维,进行ReLU激活;最后两层进行全连接,得到10个概率输出,最后进行一次Softmax的对数运算,方便后面交叉熵损失的计算。

MNIST初始图像的大小是28*28,经过大小为5的卷积核和步长为1的卷积操作后,大小变为24*24(没有padding),池化操作亦是如此,图像的大小变化如下:28*28 → 24*24 → 12*12 → 8*8 → 4*4.

卷积和池化都遵循着下面的公式:

其中n代表卷积(池化)前的图像大小,p代表padding值,f为核大小,s为步长。

二、读取数据

和Tensorflow一样,Pytorch有着方便的内建函数来下载并读取常用的数据集,MNIST就是其中之一,代码如下:

此处已经将数据分为32大小的批次为小批量数据梯度下降做好了数据准备(batch_size和test_batch_size已经在其他位置定义),对于MNIST,Pytorch会将其以6 : 1的比例切分成训练集和测试集。

这里的train_loader和test_loader都是列表,其中放置了两个张量,一个是图像的张量(4维),一个是一一对应的标签张量,均为32个,这些信息对于我们自己测试数据集之外的其他数据很有必要。(见补充)

归一化参数可以由上面的式子求得(平均数和标准差),或者论文中查到。

三、训练数据

只需要两个批次的迭代即可达到很好的训练效果。

这里我们使用到了批量梯度下降,并且将运算过程搬迁到了GPU上进行,这将大幅加快整个训练过程的速度。

有几个要点:首先是Pytorch这个库的优势,它可以用一个函数完成整个网络的反向求导,这避免了繁重的代码量(模型参数默认都会进行求导);然后是求得的梯度每次需要进行清零,否则会造成梯度的累加;保存检查点便于额外测试。

训练的到的10个概率我们选择最大概率作为输出结果来与标签比对计算损失;这里的损失函数采用的是交叉熵损失。

四、训练结果及测试

我们经过了两轮的训练就可以达到近乎99%的精度,这其实得益于前人的测试。如果我们选择FashionMNIST再用相同的参数,仅迭代两轮的话损失难以降到0.01以下,准确率仅为85%左右,解决方法有增加迭代次数以及微调参数。下面仅展示MNIST的迭代过程:

最后一个测试就是用数据集之外的数据进行测试,手写数字,通过OpenCV转变成与MNIST类似的单通道28*28图像,并且以PIL读入(不用OpenCV的原因是Pytorch是可以直接实现PIL和张量之间的转化的),效果如下:

上面的代码实现的就是图像变换,下面就是需要将其重整为我们的模型可以读入的形式,具体方案如下:

转化操作与之前大同小异,但是这里多了一步提升维度的操作,因为之前的批量读入多了一个维度就是图片张数(32张),我们即使是一张图片,也要对应进行维度提升。接下来进行模型的重载,将我们的模型和检查点的数据匹配:

最终结果如下:

补充:

下面是一些数据格式的表述:

首先,我们要进行额外测试的时候,需要仿照MNIST数据集格式,可以打印MNIST数据元素:

单个数据是一个张量,由图片和标签构成,图片的张量维度是三维:

但是,如果我们把单张图片直接输入网络会报这样的错误:

这是因为loader中的数据是批量导入,又多了一个维度:

图像和标签已经被放在两个张量(其中32个图像张量构成一个大张量,维数变为4维)中,他们又构成一个张量。

因此我们需要提升维度:

最终结果是(转移到GPU上):

参考链接:

Pytorch官方文档https://pytorch.org/docs/stable/index.html

图像的加载与读取:https://www.jianshu.com/p/cfca9c4338e7

Tensor与图像的相互转化:https://blog.csdn.net/qq_36955294/article/details/82888443

张量的维度操作:https://blog.csdn.net/weicao1990/article/details/93618136

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值