LeNet-5 手写数字识别
一、前言
-
LeNet-5出自论文Gradient-Based Learning Applied to Document Recognition,1998年由Yann LeCun提出,是一种用于手写体字符识别的非常高效的卷积神经网络。
-
MNIST 具有50000个训练数据
-
10000个测试数据
-
图像大小为28*28
-
10类
-
使用Sigmoid激活函数
二、LeNet网络架构
LeNet-5的基本结构包括7层网络结构(不含输入层),其中包括2个卷积层、2个降采样层(池化层)、2个全连接层和输出层。
1、输入层(Input layer)
输入层接收大小为32*32的手写数字图像,其中包括灰度值(0-255)。在实际应用中,我们通常会对输入图像进行预处理,例如对像素值进行归一化,以加快训练速度和提高模型的准确性。
2、卷积层C1(Convolutional layer C1)
卷积层C1包括6个卷积核,每个卷积核的大小为5 * 5,步长为1,填充为0。因此,每个卷积核会产生一个大小为28*28的特征图(输出通道数为6)。
问题一:如何得出产生的特征图尺寸大小?
记住这几个符号:
-
H:图片高度;
-
W:图片宽度;
-
D:原始图片通道数,也是卷积核个数;
-
F:卷积核高宽大小;
-
P:pad图像边扩充大小;
-
S:stride滑动步长。
由输入数据矩阵的尺寸W1 * H1 * D1 (输入层:32 * 32 * 5) ,求输出特征图组尺寸W2 * H2 * D2 ,公式如下:
W2与H2一般是相等的,带入公式W2=(32-5+2*0)/1+1=28。
-
featuremap为:28 * 28
-
神经元数量:28 * 28 * 6=4704
-
可训练参数:(5 * 5+1) * 6(每个滤波器5*5=25个unit参数和一个bias参数,一共6个滤波器)
-
连接数:(5 * 5+1)* 6 * 28 * 28=122304
详细说明:对输入图像进行第一次卷积运算(使用 6 个大小为 5 * 5 的卷积核),得到6个C1特征图(6个大小为28 * 的 feature maps, 32-5+1=28)。我们再来看看需要多少个参数,卷积核的大小为5 * 5,总共就有6*(5 * 5+1)=156个参数,其中+1是表示一个核有一个bias。对于卷积层C1,C1内的每个像素都与输入图像中的5 * 5个像素和1个bias有连接,所以总共有1562828=122304个连接(connection)。有122304个连接,但是我们只需要学习156个参数,主要是通过权值共享实现的
3、采样层S2(下采样)(Subsampling layer S2)
采样层S2采用最大池化(max-pooling)操作,每个窗口的大小为2 * 2,步长为2。因此,每个池化操作会从4个相邻的特征图中选择最大值,产生一个大小为14 * 14的特征图(输出通道数为6)。这样可以减少特征图的大小,提高计算效率,并且对于轻微的位置变化可以保持一定的不变性。
输出特征图尺寸大小 W2=28/2=14。
-
featuremap为:14 * 14
-
神经元数量:14 * 14 * 6
-
连接数:(2 * 2+1)* 6 * 14 * 14
4、卷积层C3(Convolutional layer C3)
卷积层C3包括16个卷积核,每个卷积核的大小为5 * 5,步长为1,填充为0。因此,每个卷积核会产生一个大小为10 * 10的特征图(输出通道数为16)。
输出特征图尺寸大小 W2=(14-5+2 * 0)/1+1=10。
-
featuremap为:10 * 10
5、采样层S4(Subsampling layer S4)
采样层S4采用最大池化操作,每个窗口的大小为2 * 2,步长为2。因此,每个池化操作会从4个相邻的特征图中选择最大值,产生一个大小为5 * 5的特征图(输出通道数为16)。
输出特征图尺寸大小 W2=10/2=5。
-
featuremap为:5 * 5
6、全连接层C5(Fully connected layer C5)
C5将每个大小为5 * 5的特征图拉成一个长度为400(5 * 5 * 16)的向量,并通过一个带有120个神经元的全连接层进行连接。120是由LeNet-5的设计者根据实验得到的最佳值。
-
featuremap为:1 * 1
7、全连接层F6(Fully connected layer F6)
全连接层F6将120个神经元连接到84个神经元。
8、输出层(Output layer)
输出层由10个神经元组成,每个神经元对应0-9中的一个数字,并输出最终的分类结果。在训练过程中,使用交叉熵损失函数计算输出层的误差,并通过反向传播算法更新卷积核和全连接层的权重参数。
然而,在实际应用中,通常会对LeNet-5进行一些改进,例如增加网络深度、增加卷积核数量、添加正则化等方法,以进一步提高模型的准确性和泛化能力。
三、基于pytorch的LeNet-5代码实现
# 1################加载库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 定义LeNet-5模型
class LeNet5(nn.Module):
def __init__(self):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
self.fc2 = nn.Linear(in_features=120, out_features=84)
self.fc3 = nn.Linear(in_features=84, out_features=10)
def forward(self, x):
x = self.pool1(torch.relu(self.conv1(x)))
x = self.pool2(torch.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
# 2##########加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())
# 定义数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
# 定义模型、损失函数和优化器
model = LeNet5()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())
# 3################训练模型
for epoch in range(10):
for i, (images, labels) in enumerate(train_loader):
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
if (i + 1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, 10, i + 1, len(train_loader),
loss.item()))
# 4######################测试模型
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy: {:.2f}%'.format(100 * correct / total))
运行结果:
D:\Anaconda\envs\pytorch\python.exe E:/pythoncode/深度学习/LeNet-5.py
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 11053874.52it/s]
Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<?, ?it/s]
Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 6208088.59it/s]
Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<?, ?it/s]
Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw
Epoch [1/10], Step [100/938], Loss: 0.6447
Epoch [1/10], Step [200/938], Loss: 0.1975
Epoch [1/10], Step [300/938], Loss: 0.2640
Epoch [1/10], Step [400/938], Loss: 0.3751
Epoch [1/10], Step [500/938], Loss: 0.0927
Epoch [1/10], Step [600/938], Loss: 0.0837
Epoch [1/10], Step [700/938], Loss: 0.2943
Epoch [1/10], Step [800/938], Loss: 0.1807
Epoch [1/10], Step [900/938], Loss: 0.0968
Epoch [2/10], Step [100/938], Loss: 0.1481
Epoch [2/10], Step [200/938], Loss: 0.3320
Epoch [2/10], Step [300/938], Loss: 0.1445
Epoch [2/10], Step [400/938], Loss: 0.0987
Epoch [2/10], Step [500/938], Loss: 0.0935
Epoch [2/10], Step [600/938], Loss: 0.1353
Epoch [2/10], Step [700/938], Loss: 0.0133
Epoch [2/10], Step [800/938], Loss: 0.1472
Epoch [2/10], Step [900/938], Loss: 0.1204
Epoch [3/10], Step [100/938], Loss: 0.0228
Epoch [3/10], Step [200/938], Loss: 0.1182
Epoch [3/10], Step [300/938], Loss: 0.0279
Epoch [3/10], Step [400/938], Loss: 0.0076
Epoch [3/10], Step [500/938], Loss: 0.0838
Epoch [3/10], Step [600/938], Loss: 0.0086
Epoch [3/10], Step [700/938], Loss: 0.0728
Epoch [3/10], Step [800/938], Loss: 0.0883
Epoch [3/10], Step [900/938], Loss: 0.0394
Epoch [4/10], Step [100/938], Loss: 0.0184
Epoch [4/10], Step [200/938], Loss: 0.0496
Epoch [4/10], Step [300/938], Loss: 0.1083
Epoch [4/10], Step [400/938], Loss: 0.0623
Epoch [4/10], Step [500/938], Loss: 0.1034
Epoch [4/10], Step [600/938], Loss: 0.0296
Epoch [4/10], Step [700/938], Loss: 0.0713
Epoch [4/10], Step [800/938], Loss: 0.0128
Epoch [4/10], Step [900/938], Loss: 0.0999
Epoch [5/10], Step [100/938], Loss: 0.0611
Epoch [5/10], Step [200/938], Loss: 0.0081
Epoch [5/10], Step [300/938], Loss: 0.0309
Epoch [5/10], Step [400/938], Loss: 0.0875
Epoch [5/10], Step [500/938], Loss: 0.0992
Epoch [5/10], Step [600/938], Loss: 0.1341
Epoch [5/10], Step [700/938], Loss: 0.0084
Epoch [5/10], Step [800/938], Loss: 0.0211
Epoch [5/10], Step [900/938], Loss: 0.0425
Epoch [6/10], Step [100/938], Loss: 0.0648
Epoch [6/10], Step [200/938], Loss: 0.0218
Epoch [6/10], Step [300/938], Loss: 0.1469
Epoch [6/10], Step [400/938], Loss: 0.0219
Epoch [6/10], Step [500/938], Loss: 0.0035
Epoch [6/10], Step [600/938], Loss: 0.0815
Epoch [6/10], Step [700/938], Loss: 0.0152
Epoch [6/10], Step [800/938], Loss: 0.0890
Epoch [6/10], Step [900/938], Loss: 0.0083
Epoch [7/10], Step [100/938], Loss: 0.0223
Epoch [7/10], Step [200/938], Loss: 0.0326
Epoch [7/10], Step [300/938], Loss: 0.1326
Epoch [7/10], Step [400/938], Loss: 0.0012
Epoch [7/10], Step [500/938], Loss: 0.0485
Epoch [7/10], Step [600/938], Loss: 0.0120
Epoch [7/10], Step [700/938], Loss: 0.0195
Epoch [7/10], Step [800/938], Loss: 0.0552
Epoch [7/10], Step [900/938], Loss: 0.0032
Epoch [8/10], Step [100/938], Loss: 0.0080
Epoch [8/10], Step [200/938], Loss: 0.0208
Epoch [8/10], Step [300/938], Loss: 0.0148
Epoch [8/10], Step [400/938], Loss: 0.0023
Epoch [8/10], Step [500/938], Loss: 0.0120
Epoch [8/10], Step [600/938], Loss: 0.0059
Epoch [8/10], Step [700/938], Loss: 0.0256
Epoch [8/10], Step [800/938], Loss: 0.0281
Epoch [8/10], Step [900/938], Loss: 0.0489
Epoch [9/10], Step [100/938], Loss: 0.0617
Epoch [9/10], Step [200/938], Loss: 0.0120
Epoch [9/10], Step [300/938], Loss: 0.0022
Epoch [9/10], Step [400/938], Loss: 0.0465
Epoch [9/10], Step [500/938], Loss: 0.0199
Epoch [9/10], Step [600/938], Loss: 0.0099
Epoch [9/10], Step [700/938], Loss: 0.0119
Epoch [9/10], Step [800/938], Loss: 0.0767
Epoch [9/10], Step [900/938], Loss: 0.0076
Epoch [10/10], Step [100/938], Loss: 0.0310
Epoch [10/10], Step [200/938], Loss: 0.0056
Epoch [10/10], Step [300/938], Loss: 0.0735
Epoch [10/10], Step [400/938], Loss: 0.0262
Epoch [10/10], Step [500/938], Loss: 0.0060
Epoch [10/10], Step [600/938], Loss: 0.0318
Epoch [10/10], Step [700/938], Loss: 0.0109
Epoch [10/10], Step [800/938], Loss: 0.0428
Epoch [10/10], Step [900/938], Loss: 0.0005
Test Accuracy: 98.72%
进程已结束,退出代码0
四、浅谈LeNet-5贡献
LeNet-5在当时的手写数字识别任务中取得了很好的效果,可以达到98%以上的准确率,这是当时最先进的技术水平。它的成功证明了深度学习的潜力,吸引了更多研究者加入到深度学习的研究中。同时,LeNet-5也为后来更加复杂的卷积神经网络奠定了基础,例如AlexNet、VGG、ResNet等。这些网络都采用了类似LeNet-5的卷积神经网络结构,但增加了更多的层数和参数,从而在图像分类、目标检测等任务中取得了更好的效果。虽然LeNet-5在当今深度学习的发展中已经不再是最先进的技术,但它的经典结构和训练方法仍然对深度学习的发展和应用有重要意义。