1. LeNet 简介
LeNet 原是 LeNet1 - LeNet5 这系列网络的合称,但现在所说的 LeNet 则一般特指 LeNet5 (后文统一称为 LeNet)。LeNet 是 Yann LeCun 教授于 1998 年在论文《Gradient-Based Learning Applied to Document Recognition》中提出的 ,设计之初只是用于手写数字的识别,到如今已成为卷积神经网络的 HelloWorld。受限于计算机的算力不足,加之支持向量机 (核学习方法) 的兴起,CNN 方法并未成为当时学术界认可的主流方法。
算上输入层的话,LeNet 共有 8 层,包含 3 个卷积层,2 个池化层 (下采样层) 和 1 个全连接层。其中,所有卷积操作的核都固定为 5x5,步长为 1;统一使用全局平均池化。LeNet 的网络结构如下
![](https://img-blog.csdnimg.cn/20201205163029599.png)
![](https://img-blog.csdnimg.cn/4d8791bace6b42b1a18f4405871c7bfa.png)
- 输入层:输入图像的尺寸为 32X32;
- C1 层 (卷积层):使用 6 个核大小为 5×5 的卷积,得到 6 张 28×28 的特征图;
- S2 层 (池化层,即下采样层):使用 6 个 2×2 的平均池化,得到 6 张 14×14 的特征图;
- C3 层 (卷积层):使用 16 个核大小为 5×5 的卷积,得到 16 张 10×10 的特征图;
- S4 层 (池化层):使用 16 个 2×2 的平均池化,得到 16 张 5×5 的特征图;
- C5 层 (卷积层):使用 120 个核大小为 5×5 的卷积,得到 120 张 1×1 的特征图 (一个向量);
- F6 层 (全连接层):含 84 个节点的全连接层,对应于一个 7x12 的比特图;
- 输出层:含 10 个节点的全连接层,分别代表数字 0 到 9。
2. LeNet 的 PyTorch 实现
# _*_coding:utf-8_*_
import torch
import torch.nn as nn
import torch.nn.functional as F
class LeNet5(nn.Module):
def __init__(self, in_channels, out_channels):
super(LeNet5, self).__init__()
# 卷积神经网络
self.features = nn.Sequential(
nn.Conv2d(in_channels, 6, kernel_size=5),
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(6, 16, kernel_size=5),
nn.MaxPool2d(kernel_size=2) # 原模型使用的是平均池化
)
# 分类器
self.classifier = nn.Sequential(
nn.Linear(16 * 5 * 5, 120 * 1 * 1), # 这里将第三个卷积层看成是全连接层
nn.Linear(120, 84),
nn.Linear(84, out_channels)
)
def forward(self, x):
x = self.features(x)
x = torch.flatten(x, 1) # 铺平tensor
x = self.classifier(x)
out = F.softmax(x, 1) # 激活函数
return out
if __name__ == "__main__":
batch_size = 1
in_channels = 1
out_channels = 10
inputs = torch.rand((batch_size, in_channels, 32, 32)) # (B, C, H, W)
lenet = LeNet5(in_channels, out_channels)
outputs = lenet(inputs)
print(outputs)
print(outputs.sum())
print(outputs.max())
print(torch.argmax(outputs, 1))
【参考】