目录
1. Lenet5网络结构
1.1 输入层
输入32 * 32 大小的图像
1.2 C1卷积层
输入: 32 * 32
卷积核大小: 5 * 5
卷积核个数: 6
输出featuremap大小:28 * 28 (32-5+1)=28
输出数量为:6
1.3 S2池化层(下采样层)
输入: 28 * 28
采样区域 : 2 * 2
采样方式:4个输入相加,乘以一个可训练参数,再加上一个可训练偏置,得到结果通过sigmoid 函数
采样种类:6
输出featuremap大小: 14 * 14 (28/2)
输出数量为:6
1.4 C3 卷积层
输入: 14 * 14
卷积核大小: 5 * 5
卷积核个数: 16
输出featuremap大小:10 * 10 ( 14 - 5 + 1)= 10
输出数量为:16
1.5 S4 池化层(下采样层)
输入: 10 * 10
采样区域 : 2 * 2
采样方式:4个输入相加,乘以一个可训练参数,再加上一个可训练偏置,得到结果通过sigmoid 函数
采样种类:16
输出featuremap大小: 5 * 5 (10/2)
输出数量为:16
1.6 C5 卷积层
输入: 5 * 5
卷积核大小: 5 * 5
卷积核个数: 120
输出featuremap大小:1 * 1 (5-5+1)=1
输出数量为:120
1.7 F6 全连接层
输入:c5 120维向量
计算方式:计算输入向量和权重向量之间的点积,再加上一个偏置,结果通过sigmoid函数输出。
F6层是全连接层。F6层有84个节点,对应于一个7x12的比特图,该层的训练参数和连接数都是(120 + 1)x84=10164
1.8 Output 层-全连接层
Output层也是全连接层,共有10个节点,分别代表数字0到9,如果节点i的输出值为0,则网络识别的结果是数字i。采用的是径向基函数(RBF)的网络连接方式。
2. pytorch代码
import torch
from torch import nn
from torch.nn import functional as F
class Lenet5(nn.Module):
def __init__(self):
super(Lenet5, self).__init__()
self.conv_unit = nn.Sequential(
# x: [b, 3, 32, 32] => [b, 16, ]
nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
nn.MaxPool2d(kernel_size=5, stride=2, padding=0),
nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
)
# flatten
# fc_unit
self.fc_unit = nn.Sequential(
nn.Linear(16 * 4 * 4, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 10),
)
# test conv_unit out dim
# [b, 3, 32, 32]
tmp = torch.randn(2, 3, 32, 32)
out = self.conv_unit(tmp)
# [b, 16, 5, 5]
print('conv out:', out.shape)
# # use Cross Entropy Loss
# self.criteon = nn.CrossEntropyLoss()
def forward(self, x):
"""
:param x: [b, 3, 32, 32]
:return:
"""
batchsz = x.size(0)
# [b, 3, 32, 32] => [b, 16, 4, 4]
x = self.conv_unit(x)
# flatten
x = x.view(batchsz, 16 * 4 * 4)
# [b, 16*4*4] => [b, 10]
logits = self.fc_unit(x)
# # [b, 10]
# pred = F.softmax(logits, dim=1)
# loss = self.criteon(logits, y)
return logits
def main():
net = Lenet5()
tmp = torch.randn(2, 3, 32, 32)
out = net(tmp)
print('lenet out:', out.shape)
if __name__ == '__main__':
main()
参考:
https://blog.csdn.net/budong282712018/article/details/102684216