【小白也可做】pytorch实现普通CNN的MNIST手写数字分类,含t-SNE聚类图、混淆矩阵图绘制,内含通用代码,可根据自己的项目需要进行修改

一、项目简介

本项目是基于pytorch使用两层CNN网络实现手写数字的分类识别,并且绘制了损失和准确率曲线、训练前后的t-SNE聚类图、混淆矩阵图。

二、数据集

常用的手写数字数据集MNIST,这个大家自行百度就有很多说明这个数据集的文章啦,里面的图片大概是长这样子的。

三、实验环境

平台:Window 11

语言:python3.9

编译器:Pycharm

框架:Pytorch:1.13.1

四、实验内容及部分代码展示

1、Model.py 模型构建

该项目使用的网络包含2维卷积、池化层、全连接层,通过ReLU激活函数进行非线性变换

2、train.py 用于分类的训练通用模板

3、Config.py 参数定义

config类中定义了项目所有需要的参数,可以在里面修改训练参数。

4、mnist_class_cnn_run.py 运行文件

该py文件实现整体训练流程并做绘图操作。依次实现加载数据、数据格式转化、划分训练集测试集、形成数据更迭器、载入模型、定义损失、定义优化器、开始训练、损失可视化、显示预测结果。

5、test_pth.py 模型训练后的测试文件

采用模型训练完成后的pth对数据进行预测,可以展示模型预测效果,前面对数据的处理过程类似mnist_class_cnn_run.py所示。

6、draw_loss_acc.py 模型训练后的loss绘图

将训练后产生并收集的损失loss.csv和准确率acc.csv展示出来,也就是损失和准确率变化曲线。

7、tsne_plot.py 模型训练后用于绘制t-SNE聚类图

绘制了训练前和训练后样本的t-SNE聚类图

8、matrix_plot.py 绘制模型训练后的混淆矩阵

五、实验效果分析

1、loss损失图

该损失是训练了50个epoch的损失图

2、acc准确率图

3、test_pth.py的预测效果展示 

其中红框是预测有误的,有误的概率比较小

4、matrix_plot.py的混淆矩阵效果展示 

第一个图是训练了第一个epoch后的混淆矩阵,第二个图是训练了50个epoch的混淆矩阵,横坐标是预测值,纵坐标是真实值,中间的数字指的是样本数,比如第一个数字976指的是有真实值为0的976个样本预测出是0,同一行后面有4个真实值是0的样本预测的是6和8。相比于第1个和第50个epoch的混淆矩阵中的数据,对角线处其对应正确预测的样本数越来越大,说明模型训练有效果。

5、matrix_plot.py的混淆矩阵效果展示 

第一个是未训练时的样本聚类图,第二个是训练后的样本聚类图,模型训练后,各数字的分布明显分隔开了,说明模型对数字识别分类有效果。

六、资源与总结

若有朋友需要可运行的源码和数据集,可以guan注【科研小条】公众号,回复【手写数字分类】,即可获得。

  • 8
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

科研小条

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值