使用深度学习进行手写数字识别

 一、全连接神经网络进行手写数字识别

1.1加载和处理数据

从本地文件加载MNIST数据集,然后对数据进⾏预处理。MNIST数据集中的⼿写数字被表

⽰为28x28的向量,我们将其展开为⼀个784维的特征向量。此外,我们还需要对输⼊数据进⾏归⼀化,通常将像素值从 0-255缩放到0-1的范围内,此外,我们还需要把标签转换为one-hot形式 ,并且将原始训练集划分为训练集验证集和测试集。

1.2构建全连接模型

输入层:

没有显式定义输入层,但输入数据X被假定为一个矩阵,其维度是(m,input_size),其中m是批量大小,input_size是输入数据的特征数量

隐藏层:

显式定义了一个隐藏层,其大小由参数hidden_size决定。隐藏层的权重W1是一个(input_size, hidden_size)维的矩阵,偏置b1是一个(1, hidden_size) 维的向量

激活函数层:

隐藏层后面跟着一个激活函数层,激活函数由Activation参数决定,可以是 ReLU、sigmoid、tanh 或 leaky_relu。激活函数的作用是引入非线性,帮助网络学习更复杂的模式

输出层:

输出层的权重W2是一个(hidden_size, output_size)维的矩阵,偏置b2是一个(1, output_size) 维的向量。这里的output_size通常是类别的数量,对于MNIST数据集来说是10(代表数字0到9)

输出激活函数:

输出层后面应用了一个 softmax 激活函数,用于将输出层的原始分数转换为概率分布,使得所有类别的预测概率之和为1

1.3训练模型

设置批次batch_ = 6000,学习率learning_rate = 0.001,轮次num_epochs = 100,使用训练集训练模型,验证集⽤于调整模型的超参数和评估模型的性能,模型的训练2分钟即可完成

1.4评估模型

测试集是模型从未见过的数据集,使用测试集可以评估模型的性能和泛化能力

1.5训练结果

模型训练集上有 98.83%的准确率,在测试集上有97.56%的准确率。

总结:这个代码使用了全连接网络进行了手写数字识别,准确率较为满意,使用模型训练时间(未使用GPU加速)也比较快,适合初学者学习,大家也可以调整一下超参数看看能不能取得更好的成功率,同时作者也在文件里面放了详细的介绍。

参考代码:https://github.com/rascal-yang/MLP.git

二、使用残差网络进行手写数字识别

2.1.残差网络相较于传统网络优势优势

目录

1.2构建全连接模型

1.3训练模型

1.4评估模型

1.5训练结果

2.1.残差网络相较于传统网络优势优势

2.1.1.残差学习

2.1.2.恒等映射

2.1.3.深层网络

2.1.4.退化问题解决

2.2读取数据

2.3损失函数和优化方式

2.4超参数设置

2.5 环境配置和训练模型

2.6保存模型

2.7实验结果

2.8额外功能

2.8.1粉笔图片识别

2.8.2 摄像头实时识别

2.9参数调整

3.0 GPU加速


2.1.1.残差学习

ResNet的核心思想是让网络中的每个残差块(Residual Block)学习输入和输出之间的残差(即差异),而不是直接学习映射。这意味着网络的每个部分只需要在现有层的基础上进行微调,而不是从头开始学习整个映射

2.1.2.恒等映射

在残差块中,如果输入和输出相同,那么网络可以简单地通过恒等映射(即直接将输入复制到输出)来实现,而不需要学习任何复杂的变换。这使得网络可以更容易地学习到恒等函数,从而避免了梯度消失或爆炸的问题。

2.1.3.深层网络

由于残差学习框架的引入,ResNet可以构建非常深的网络结构(例如,100层或更深),而不会遭受传统深层网络中的退化问题。这使得网络能够学习到更复杂的特征表示,从而提高性能。

2.1.4.退化问题解决

在传统的深度神经网络中,随着层数的增加,梯度可能会消失或爆炸,导致网络难以训练。ResNet通过残差连接(Skip Connection)直接连接前面的层和后面的层,有效地解决了这个问题

参考文献:ResNet-18超详细介绍!!!!_resnet18-CSDN博客

2.2读取数据

读取IDX文件:

IDX文件是一种二进制格式,用于存储FishionMint数据集中的图像和标签。代码中定义了两个函数 decode_idx3_ubyte 和 decode_idx1_ubyte 来读取图像和标签文件

读取图像数据 (decode_idx3_ubyte):

这个函数读取图像数据文件(.idx3-ubyte),它首先读取整个文件到二进制数据 bin_data,

使用 struct.unpack_from 来解析文件头部信息,包括魔术数(magic number)、图像数量、图像行数和列数。然后,函数计算每个图像数据的偏移量,并使用 struct.unpack_from 读取每个图像的像素值,将其转换为28x28的NumPy数组

读取标签数据 (decode_idx1_ubyte):

类似于读取图像数据,这个函数读取标签数据文件(.idx1-ubyte),

解析头部信息,获取标签数量,

然后逐个读取每个标签,并存储到NumPy数组中

加载数据的辅助函数:

load_train_images 和 load_train_labels 函数用于加载训练集的图像和标签

load_test_images 和 load_test_labels 函数用于加载测试集的图像和标签

2.3损失函数和优化方式

loss = torch.nn.CrossEntropyLoss()  #模型使用的损失函数是交叉熵损失,这是多分类问题中常用的损失函数

使用了Adam优化器,这是一种常用的优化器,它自动调整学习率

2.4超参数设置

学习率被设置为 0.001,批次大小被设置为 1000,训练轮数(epoch)被设置为 100。

2.5 环境配置和训练模型

在hand_wrtten_train.py里面训练模型,只需要简单修改一下导入数据那几行代码的地址,将绝对路径换成相对路径或者修改为自己的绝对路径即可,主要的包有:numpy、struct、matplotlib、OpenCV、Pytorch、torchvision、tqdm,使用pip命令下载即可

2.6保存模型

模型会保存在log文件下

2.7实验结果

训练集准确率:100%

测试集准确率:99.3%

实验结果可见残差神经网络相较于前一个网络的优势是明显提升的,这里训练集达到100%不能简单认为模型过拟合,因为手写数字图片比较简单,模型在测试集上有99.3的准确率,可以说明模型表现良好。

2.8额外功能

2.8.1粉笔图片识别

使用了20张粉笔图片进行识别,分别在两个文件real_img和real_img_resize两个文件里面,每个文件个各九张图片,两个文件的九张图片均为1到9。

注意事项:main_pthoto.py文件中图片路径需要修改为自己的路径以及predict.py文件中权重片路径需要修改为自己在训练得到的模型保存文件路径。

实验结果:在real_img_resize文件的九张图片未识别出数字3,而real_img文件有三张图片识别错误,由于粉笔图片与训练集和测试集的图片有所不同,所以模型的效果还是令人满意。

2.8.2 摄像头实时识别

代码存在于main.py文件下,使用方法运行即可打开摄像头进行识别

2.9参数调整

学习率调整:将原学习率由0.001调整为0.01后在第70个epoch的模型训练集准确率100%,测试集99.2%,几乎不变,在real_img_resize文件的九张粉笔数字图片全部识别成功,real_img文件的数字4未识别成功,较原来的模型有所提升。

批次调整:将批次由原来的1000改为100后在第70个epoch的模型训练集准确率100%,测试集99.5%,变化很小,在两个文件的18张图片测试未识别成功5张图片,较原来模型有所下降。

3.0 GPU加速

由于这篇代码只使用CPU训练时间特别长,同时代码也设置了可用GPU训练,所以我这里放一个简单安装GPU的链接:全网最详细的安装pytorch GPU方法,一次安装成功!!包括安装失败后的处理方法!-CSDN博客

总结:

相较于上一篇代码,这篇使用的 FishionMint 数据 集是 MNIST 的一个替代品,图像内容是时尚商品,而不是手写数字,且每个类别的图像 数量更平均,这使得它在某些方面比 MNIST 更具挑战性, 模型一般在经过 40epoch 时损失函数就可以及其接近 0 了,在训练集上高达 100 之百的准确率,在测试集上的准确率高达百分之 99.3,实现 了测试集准确率超百分之 90。

总之,这篇代码不仅实现了测试集上百分之 90 的成功率,并且还增加了对图片的识别,

提高了模型的泛化能力,还实现了实时检测功能。

参考CSDN 地址:基于卷积神经网络的手写数字识别(附数据集+完整代码+操作说明)-CSDN博客

参考GitHub 代码地址:GitHub - Hurri-cane/Hand_wrtten at master

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值