LeNet-5实现手写数字识别

一、手写数字识别原理

模型的输入数据是包含手写数字信息的二维图像,将其输入到网络模型中,经过模型的前向计算得到输出的识别结果,通过损失函数度量计算输出结果与输入图像标签的差异度,并通过反向传播算法根据这个差异来调整网络各层的参数值,经过反复迭代输入,最终得到一个能准确识别输入图像的网络模型(输出的识别结果与标签一致)

二、网络结构

LetNet-5网络结构如下图,输入的二维图像经过两个卷积层、一个池化层,再经过三个全连接层将前层计算得到的特征空间映射样本标记空间,最后使用softmax分类作为输出层。

各层参数:

输入大小:28*28*1

卷积层①:5*5*6,padding=2

池化层:2*2最大池化

卷积层②:5*5*16,padding=0

全连接层①:输入400、输出120

全连接层②:输入120、输出84

全连接层③:输入84、输出10

输出层:softmax

三、代码实现

import torch
from torch.autograd import Variable
import torch.nn as nn
import torchvision

"""MNIST数据集"""
train_dataset = torchvision.datasets.MNIST("dataset", train=True,
                                           transform=torchvision.transforms.ToTensor(), download=True)
test_dataset = train_dataset

# DataLoader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=128,
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=100,
                                          shuffle=False)

"""LetNet-5"""
class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)  # padding = 2, 28+2+2=32
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(400, 120)  # 400=16*5*5
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.softmax = nn.Softmax()

    def forward(self, x):
        in_size = x.size(0)
        out = self.relu(self.pool(self.conv1(x)))
        out = self.relu(self.pool(self.conv2(out)))
        out = out.view(in_size, -1)
        out = self.relu(self.fc1(out))
        out = self.relu(self.fc2(out))
        out = self.fc3(out)
        return self.softmax(out,)


model = LeNet5()

# 损失函数:交叉熵损失
loss_func = torch.nn.CrossEntropyLoss()

# 定义优化器
opt = torch.optim.Adam(model.parameters(), lr=0.001)


def train(epoch):
    model.train()
    for batch_index, (data, target) in enumerate(train_loader):
        data, target = Variable(data), Variable(target)
        opt.zero_grad()  # backward前梯度清零
        output = model(data)
        loss = loss_func(output, target)
        # 误差反向传播
        loss.backward()
        # 参数更新
        opt.step()
        if batch_index % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_index * len(data), len(train_loader.dataset),
                       100. * batch_index / len(train_loader), loss.item()))


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        # 叠加loss
        test_loss += loss_func(output, target).item()
        # 最大概率预测结果标签
        pred = torch.max(output.data, 1)[1]
        correct += pred.eq(target.data.view_as(pred)).cpu().sum()
        test_loss /= len(test_loader.dataset)

        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))


# 迭代10轮后测试
for epoch in range(1, 11):
    train(epoch)

test()

四、实验结果

经过10轮训练,将训练后的网络模型用来测试原始数据集,60000张图像中识别正确的数量为59239,识别正确率为99%。

  • 7
    点赞
  • 69
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
LeNet-5是一种经典的卷积神经网络,用于手写数字识别。它在1998年由Yann LeCun等人提出,旨在通过学习感知到的局部特征来实现数字的自动识别和分类。 LeNet-5主要由两个重要部分组成:卷积神经网络(CNN)和全连接层。 输入图像首先经过两个卷积层和池化层,用于提取图像的特征。卷积层通过滑动窗口计算每个窗口中的特征,然后池化层对特征图进行降采样,减少计算量和参数个数。随后,通过几个全连接层对提取的特征进行分类,最终输出层得到识别结果。 在训练阶段,LeNet-5使用反向传播算法来更新网络权重,最小化训练样本与目标标签之间的损失函数。该损失函数可衡量网络对不同数字的分类准确性。 为了识别手写数字'c',我们需要准备一组训练样本包含手写数字'c'的图像及其标签,并将这些样本输入LeNet-5进行训练。训练过程中,网络将学习到特定于'c'的特征,以便能够准确地区分出'c'与其他数字。 完成训练后,我们可以用测试集对LeNet-5进行评估。将手写数字'c'的图像输入网络,根据输出层的预测结果即可进行识别判断。如果网络的输出结果与'c'标签匹配,则说明LeNet-5成功地识别了手写数字'c'。 总而言之,LeNet-5是一种使用卷积神经网络实现手写数字识别的经典模型。通过训练和调整网络权重,LeNet-5能够识别手写数字'c'。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

bluebub

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值