项目路径:D:\pycharm_code\template\pilibala_template\deep-learning-for-image-processing-master\pytorch_classification\Test1_official_demo
现在pytorch官方有关Lenet的demo已经下架了
(1)注意:pytorch tensor的通道排序为[batch,channel,height,width],本代码中的数据集为彩色图像,为RGB三种颜色,故input是32*32*3
(2)LeNet类继承自父类nn.Module,super函数解决的是继承过程中出现的问题
代码与上面结构图是类似的对应,并不是完全一致的对应
1.输入为RGB彩图,为3通道,设置通过第一个卷积层输出16通道,由于第一层卷积核为5*5,则输入32*32,得到输出为28*28的;
2.下采样,通过2*2,stride=2的池化核,得到(16,14,14)
3.再通过nn.Conv2d(16, 32, 5),得到(32,10,10)
4.下采样,再通过2*2,stride=2的池化核,得到(32,5,5)
5.通过全连接层需要将(32,5,5)的向量变成一维向量,即32*5*5作为nn.Linear的in_features,对照Lenet网络结构图,得知其out_features为120,故参数为(32*5*5,120)
6.后面的全连接层同理,注意到最后一个全连接层的输出参数为10,因为所采用的分类数据集为Cifar10,分为10类
(3)在train.py中,首次下载cifar10的训练集,需要将download设置为ture
注意:这里import torchvision标红了是因为没有选对python解释器,记得选一下
后面的trian_loader的num_worker只有为0时,在Windows环境下不会报错。
输入torchvision.datasets可以看到pytorch官方有很多数据集。
(4)验证一下
先反归一化,再将图像由tensor转化为numpy格式,由于之前采用了to tensor,而to tensor将(H,W,C)转化为了(C,H,W),因此需要将现在(1,2,0)才能变回(H,W,C)
(5)
(6)再在predict.py文件验证训练后的模型参数对一个全新的图片1.jpg的分类效果
(7)cifar10的readme
在每个epoch下要有两段结果展示,如[1,500],[1,1000],因为cifar10中验证集有10000张图片,而val_loader的batchsize为5000。