一、问题描述
使用PyTorch实现LeNet,进行简单的图片分类。
使用LeNet训练CIFAR10数据集,实现输入一张RGB图片,输出该图片包含对象的类别。
二、代码实现
2.1 文件结构
- data: 训练数据集地址
- models: 分类模型地址
- LeNet.py: LeNet模型搭建
- weights: 模型权重存放地址
- example: 测试模型所使用的图片
- train.py: 对模型进行训练
- predict.py: 测试模型的分类效果
2.2 环境配置
- Python3.10
- PyTorch2.0.0
- Ubuntu22.04
2.3 模型实现
LeNet.py
import torch.nn as nn
import torch.nn.functional as func
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = func.relu(self.conv1(x)) # input: (3, 32, 32) output: (16, 28, 28)
x = self.pool1(x) # output: (16, 14, 14)
x = func.relu(self.conv2(x)) # output: (32, 10, 10)
x = self.pool2(x) # output: (32, 5, 5)
x = x.view(-1, 32 * 5 * 5) # output: (32*5*5)
x = func.relu(self.fc1(x)) # output: (120)
x = func.relu(self.fc2(x)) # output: (84)
x = self.fc3(x) # output: (10)
return x
2.4 结果展示
测试图片
测试结果:
2.5 项目地址
https://github.com/piggy-wanger/LeNet