前言
以下文章均为学习笔记,目的是加强自己的记忆,同时希望帮助更多的学习者理解视频中的内容
是跟着一位优秀的b站up主霹雳吧啦Wz学习的
附上视频链接:(2.1 pytorch官方demo(Lenet)_哔哩哔哩_bilibili
另外笔记是参考另一位博主,小白刚开始写笔记
附上文章链接:(5条消息) pytorch图像分类篇:2.pytorch官方demo实现一个分类器(LeNet)_Fun’的博客-CSDN博客_pytorch图像分类
如有侵权联系我删除
pytorch官网入门demo——LeNet图像分类器
基于模型:LeNet
通过官方给的数据集可以看到,输入的数据格式为3×32×32
demo流程
- model —— 构建LeNet网络结构
- train —— 训练模型
- test —— 模型测试
1、model()
先给出代码
注意:pytorch框架中的Tensor通道排序为[batch_size, channel, height, width]
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self