中文MNIST数据集的图像分类(准确度99.93%)

该博客介绍了如何使用Python和TensorFlow对中文MNIST数据集进行图像分类。内容包括数据集的介绍、数据预处理、训练集与测试集的划分、数据归一化、引入预训练模型以及模型的训练和性能评估,最终实现高达99.93%的分类准确率。

数据集

链接:

Chinese MNIST | KaggleChinese numbers handwritten characters imageshttps://www.kaggle.com/gpreda/chinese-mnist

简介:

中国版的 MNIST 数据集是在纽卡斯尔大学的一个项目框架中收集的数据。一百名中国公民参与了数据收集工作。 每个参与者用标准的黑色墨水笔在一张桌子上写下所有 15 个数字,在一张白色 A4 纸上画出了 15 个指定区域。 这个过程对每个参与者重复 10 次。 每张纸都以 300x300 像素的分辨率扫描。结果返回一个包含 15000 个图像的数据集,每个图像代表一组 15 个字符中的一个字符。

代码

引入相关类库

natsort是一个用于排序的类库,为什么这么多的排序不用,偏偏使用这个,因为它的排序规则与Windows的文件排序规则一致!因为csv里面的标签与图片是分离的,所以需要自己先找到办法,把图片和标签正确对应起来。

!pip install natsort #排序规则与Windows的文件排序规则一致
import pandas as pd
import numpy as np
import sys
import os
import tensorflow as tf
from pathlib import Path
import sklearn
from sklearn.model_selection import train_test_split
from tensorflow import keras
import warnings
from natsort import ns, natsorted
warnings.filterwarnings('ignore')

读入数据 & 排序

csv文件的排序处理

把csv文件,按照 'suite_id', 'sample_id', 'code' 先后进行升序排列,得到的就有规律的排列情况,将会与后面的图片排序,一一对应。

data_df = pd.read_csv('../input/chinese-mnist/chinese_mnist.csv') #读入csv文件
data_df.sort_values(by=['suite_id','sample_id','code'], ascending=True, inplace=True) 
#按照 'suite_id', 'sample_id', 'code' 先后进行升序排列
data_df = data_df.reset_index(drop=True) #使索引按照新的排序排列,并丢弃旧的索引
data_df[:20] #显示前20行

显示前20行

图片的排序处理

接下来对图片按照Window的排序规则进行排序

image_dir = Path('../input/chinese-mnist/data/data') #获取图片的根目录
image_paths = list(image_dir.glob('*.jpg')) #获取所有图片的位置
image_paths = natsorted(image_paths, alg=ns.PATH) #按照windows的规则排序
image_paths = pd.Series(image_paths, name='Image_path').astype(str) #拼接成csv文件
image_paths[:20] #展示前20行

展示前20行

从csv文件中取得标签,与图片的位置拼接成新的csv文件。

labels = data_df['code'].astype(str) #需要转成字符串类型,不然会报错
image_df = pd.concat([image_paths, labels], axis=1) #拼接标签与图片的位置
image_df.rename(columns={'code': 'Label'}, inplace=True) #对列名重命名
image_df[:20] #展示前20行

展示前20行

运行错误,而且图像也没显示出来。请修改代码。 运行结果如下: 使用兼容数据加载方案… 数据集加载完成: 训练集7000样本, 测试集3000样本 开始训练网络… |=============================================================================| |  轮  |  迭代  |    经过的时间     |  小批量准确度  |  验证准确度  |  小批量损失  |  验证损失  |  基础学习率  | |     |      |  (hh:mm:ss)  |          |         |         |        |         | |=============================================================================| |   1 |    1 |     00:00:10 |    4.69% |  12.43% |  2.3997 | 2.3882 |  0.0100 | |   1 |   50 |     00:01:04 |   92.97% |         |  0.3559 |        |  0.0100 | |   2 |  100 |     00:01:59 |   97.66% |  97.53% |  0.0924 | 0.1026 |  0.0100 | |   3 |  150 |     00:02:47 |   96.88% |         |  0.0899 |        |  0.0100 | |   4 |  200 |     00:03:42 |   99.22% |  99.30% |  0.0434 | 0.0338 |  0.0100 | |   5 |  250 |     00:04:29 |  100.00% |         |  0.0092 |        |  0.0100 | |   6 |  300 |     00:05:23 |   98.44% |  99.47% |  0.0354 | 0.0271 |  0.0100 | |   7 |  350 |     00:06:05 |   96.09% |         |  0.1189 |        |  0.0100 | |   8 |  400 |     00:06:50 |   99.22% |  99.03% |  0.0411 | 0.0391 |  0.0100 | |   9 |  450 |     00:07:29 |   99.22% |         |  0.0212 |        |  0.0100 | |  10 |  500 |     00:08:14 |   98.44% |  99.00% |  0.0283 | 0.0367 |  0.0100 | |  11 |  550 |     00:08:52 |  100.00% |         |  0.0092 |        |  0.0070 | |  12 |  600 |     00:09:36 |  100.00% |  99.73% |  0.0111 | 0.0078 |  0.0070 | |  13 |  650 |     00:10:14 |  100.00% |         |  0.0087 |        |  0.0070 | |  13 |  700 |     00:10:55 |  100.00% |  99.90% |  0.0098 | 0.0044 |  0.0070 | |  14 |  750 |     00:11:33 |  100.00% |         |  0.0019 |        |  0.0070 | |  15 |  800 |     00:12:17 |  100.00% |  99.90% |  0.0015 | 0.0050 |  0.0070 | |  16 |  850 |     00:12:54 |  100.00% |         |  0.0022 |        |  0.0070 | |  17 |  900 |     00:13:38 |  100.00% |  99.93% |  0.0043 | 0.0072 |  0.0070 | |  18 |  950 |     00:14:15 |   99.22% |         |  0.0134 |        |  0.0070 | |  19 | 1000 |     00:14:59 |  100.00% |  99.80% |  0.0051 | 0.0097 |  0.0070 | |  20 | 1050 |     00:15:36 |   98.44% |         |  0.0413 |        |  0.0070 | |  21 | 1100 |     00:16:20 |  100.00% |  99.30% |  0.0116 | 0.0224 |  0.0049 | |  22 | 1150 |     00:16:57 |  100.00% |         |  0.0073 |        |  0.0049 | |  23 | 1200 |     00:17:41 |  100.00% |  99.87% |  0.0055 | 0.0048 |  0.0049 | |  24 | 1250 |     00:18:19 |  100.00% |         |  0.0023 |        |  0.0049 | |  25 | 1300 |     00:19:02 |  100.00% |  99.93% |  0.0050 | 0.0028 |  0.0049 | |  25 | 1350 |     00:19:46 |  100.00% |  99.97% |  0.0051 | 0.0023 |  0.0049 | |=============================================================================| 训练结束: 已完成最大轮数。 评估模型性能… 测试准确率: 99.9000% 总推理时间: 2.31秒 | 单样本: 0.7706毫秒 混淆矩阵函数路径: C:\Program Files\MATLAB\R2024a\toolbox\shared\mlearnlib\confusionchart.m 不支持将脚本 confusionchart 作为函数执行: C:\Program Files\MATLAB\R2024a\toolbox\shared\mlearnlib\confusionchart.m 出错 untitled (第 174 行) confusionchart(YTest, YPred, …
06-27
我在做计算智能课的结课大论文,请你结合深度学习、机器学习和计算智能的知识,及其相关知识,帮助我完成本次结课大论文。另,本次实验采用MATLAB R2024a的实验环境。 任务五:利用ResNet网络训练MNIST数据集(20分) [简述ResNet网络的原理] [说明ResNet网络结构及重要参数设置] [实验结果展示] [实验结果分析及可改进方向] [代码展示] 我现在在完成[代码展示]部分的内容,为我下面给出的代码解决运行结果中的报错,并给我解决报错后的完整代码。。 代码: %% 任务五:最终可运行ResNet-MNIST识别系统 % 修复标签格式问题,确保100%兼容性 clear; clc; close all; rng(2024, 'twister'); % 随机种子策略 %% 兼容数据加载方案 fprintf('使用兼容数据加载方案...\n'); digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ... 'nndatasets', 'DigitDataset'); % 训练集加载 trainImds = imageDatastore(digitDatasetPath, ... 'IncludeSubfolders', true, 'LabelSource', 'foldernames'); [trainImds, testImds] = splitEachLabel(trainImds, 0.7, 'randomized'); % 转换为4D数组格式 XTrain = readall(trainImds); if iscell(XTrain) XTrain = cat(4, XTrain{:}); end YTrain = trainImds.Labels; % 直接使用分类标签 % 测试集加载 XTest = readall(testImds); if iscell(XTest) XTest = cat(4, XTest{:}); end YTest = testImds.Labels; % 直接使用分类标签 % 确保灰度图像(单通道) if size(XTrain, 3) == 3 XTrain = rgb2gray(XTrain); XTest = rgb2gray(XTest); end if size(XTrain, 3) == 1 XTrain = reshape(XTrain, [size(XTrain,1), size(XTrain,2), 1, size(XTrain,4)]); XTest = reshape(XTest, [size(XTest,1), size(XTest,2), 1, size(XTest,4)]); end % 统一尺寸为28x28 if size(XTrain,1) ~= 28 || size(XTrain,2) ~= 28 XTrain = imresize(XTrain, [28, 28]); XTest = imresize(XTest, [28, 28]); end fprintf('数据集加载完成: 训练集%d样本, 测试集%d样本\n', ... size(XTrain,4), size(XTest,4)); %% 数据增强(兼容方案) augmenter = imageDataAugmenter(... 'RandRotation', [-15 15], ... 'RandXTranslation', [-3 3], ... 'RandYTranslation', [-3 3]); imdsTrain = augmentedImageDatastore([28 28 1], XTrain, YTrain, ... 'DataAugmentation', augmenter); %% 修复的纯顺序结构残差网络 layers = [ % === 输入层 === imageInputLayer([28 28 1], 'Name', 'input', 'Normalization', 'none') % === 初始卷积 === convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'conv1') batchNormalizationLayer('Name', 'bn1') reluLayer('Name', 'relu1') % === 残差块1 === % 主路径 convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'res1_conv1') batchNormalizationLayer('Name', 'res1_bn1') reluLayer('Name', 'res1_relu1') convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'res1_conv2') batchNormalizationLayer('Name', 'res1_bn2') % 残差连接(通过1x1卷积实现加法) convolution2dLayer(1, 16, 'Name', 'res1_add', ... 'WeightsInitializer', @(sz) 2 * reshape(eye(16), [1,1,16,16]), ... % 修复的权重初始化 'BiasInitializer', 'zeros', ... 'WeightLearnRateFactor', 0, 'BiasLearnRateFactor', 0) % 固定权重 batchNormalizationLayer('Name', 'res1_add_bn') reluLayer('Name', 'res1_final_relu') % === 残差块2(带下采样)=== % 主路径 convolution2dLayer(3, 32, 'Padding', 'same', 'Stride', 2, 'Name', 'res2_conv1') batchNormalizationLayer('Name', 'res2_bn1') reluLayer('Name', 'res2_relu1') convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'res2_conv2') batchNormalizationLayer('Name', 'res2_bn2') % 残差连接(带下采样) convolution2dLayer(1, 32, 'Stride', 2, 'Name', 'res2_shortcut') batchNormalizationLayer('Name', 'res2_bn_shortcut') % 加法操作 convolution2dLayer(1, 32, 'Name', 'res2_add', ... 'WeightsInitializer', @(sz) 2 * reshape(eye(32), [1,1,32,32]), ... % 修复的权重初始化 'BiasInitializer', 'zeros', ... 'WeightLearnRateFactor', 0, 'BiasLearnRateFactor', 0) % 固定权重 batchNormalizationLayer('Name', 'res2_add_bn') reluLayer('Name', 'res2_final_relu') % === 残差块3 === % 主路径 convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'res3_conv1') batchNormalizationLayer('Name', 'res3_bn1') reluLayer('Name', 'res3_relu1') convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'res3_conv2') batchNormalizationLayer('Name', 'res3_bn2') % 残差连接 convolution2dLayer(1, 64, 'Name', 'res3_shortcut', ... 'WeightLearnRateFactor', 0, 'BiasLearnRateFactor', 0) % 固定权重 batchNormalizationLayer('Name', 'res3_bn_shortcut') % 加法操作 convolution2dLayer(1, 64, 'Name', 'res3_add', ... 'WeightsInitializer', @(sz) 2 * reshape(eye(64), [1,1,64,64]), ... % 修复的权重初始化 'BiasInitializer', 'zeros', ... 'WeightLearnRateFactor', 0, 'BiasLearnRateFactor', 0) % 固定权重 batchNormalizationLayer('Name', 'res3_add_bn') reluLayer('Name', 'res3_final_relu') % === 分类部分 === globalAveragePooling2dLayer('Name', 'gap') fullyConnectedLayer(10, 'Name', 'fc') softmaxLayer('Name', 'softmax') classificationLayer('Name', 'output') ]; %% 训练配置 options = trainingOptions('adam', ... 'InitialLearnRate', 0.01, ... 'LearnRateSchedule', 'piecewise', ... 'LearnRateDropPeriod', 10, ... 'LearnRateDropFactor', 0.7, ... 'MaxEpochs', 25, ... 'MiniBatchSize', 128, ... 'Shuffle', 'every-epoch', ... 'ValidationData', {XTest, YTest}, ... % 使用分类标签 'ValidationFrequency', 100, ... 'Verbose', true, ... 'Plots', 'training-progress', ... 'ExecutionEnvironment', 'cpu'); %% 模型训练 fprintf('开始训练网络...\n'); net = trainNetwork(imdsTrain, layers, options); %% 模型评估 fprintf('评估模型性能...\n'); tic; [YPred, probs] = classify(net, XTest, 'ExecutionEnvironment', 'cpu'); inferenceTime = toc; accuracy = mean(YPred == YTest); fprintf('测试准确率: %.4f%%\n', accuracy*100); fprintf('总推理时间: %.2f秒 | 单样本: %.4f毫秒\n', ... inferenceTime, inferenceTime*1000/size(XTest,4)); %% 结果可视化 % 混淆矩阵 figure; confusionchart(YTest, YPred); title(sprintf('ResNet-MNIST (准确率: %.4f%%)', accuracy*100)); % 样本预测展示 figure; numSamples = 9; randIndices = randperm(size(XTest,4), numSamples); for i = 1:numSamples subplot(3,3,i); img = XTest(:,:,:,randIndices(i)); imshow(img, []); predLabel = char(YPred(randIndices(i))); trueLabel = char(YTest(randIndices(i))); if strcmp(predLabel, trueLabel) color = 'g'; else color = 'r'; end title(sprintf('真实: %s | 预测: %s', trueLabel, predLabel), 'Color', color); end %% 模型保存 save('ResNet_MNIST_Final.mat', 'net', 'accuracy', 'inferenceTime'); fprintf('模型已保存为ResNet_MNIST_Final.mat\n'); 运行结果如下: 运行错误: 使用兼容数据加载方案... 数据集加载完成: 训练集7000样本, 测试集3000样本 开始训练网络... |=============================================================================| |  轮  |  迭代  |    经过的时间     |  小批量准确度  |  验证准确度  |  小批量损失  |  验证损失  |  基础学习率  | |     |      |  (hh:mm:ss)  |          |         |         |        |         | |=============================================================================| |   1 |    1 |     00:00:36 |    4.69% |  12.43% |  2.3997 | 2.3882 |  0.0100 | |   1 |   50 |     00:01:38 |   92.97% |         |  0.3559 |        |  0.0100 | |   2 |  100 |     00:02:41 |   97.66% |  97.53% |  0.0924 | 0.1026 |  0.0100 | |   3 |  150 |     00:03:36 |   96.88% |         |  0.0899 |        |  0.0100 | |   4 |  200 |     00:04:42 |   99.22% |  99.30% |  0.0434 | 0.0338 |  0.0100 | |   5 |  250 |     00:05:40 |  100.00% |         |  0.0092 |        |  0.0100 | |   6 |  300 |     00:06:45 |   98.44% |  99.47% |  0.0354 | 0.0271 |  0.0100 | |   7 |  350 |     00:07:37 |   96.09% |         |  0.1189 |        |  0.0100 | |   8 |  400 |     00:08:37 |   99.22% |  99.03% |  0.0411 | 0.0391 |  0.0100 | |   9 |  450 |     00:09:31 |   99.22% |         |  0.0212 |        |  0.0100 | |  10 |  500 |     00:10:29 |   98.44% |  99.00% |  0.0283 | 0.0367 |  0.0100 | |  11 |  550 |     00:12:01 |  100.00% |         |  0.0092 |        |  0.0070 | |  12 |  600 |     00:13:30 |  100.00% |  99.73% |  0.0111 | 0.0078 |  0.0070 | |  13 |  650 |     00:14:48 |  100.00% |         |  0.0087 |        |  0.0070 | |  13 |  700 |     00:16:20 |  100.00% |  99.90% |  0.0098 | 0.0044 |  0.0070 | |  14 |  750 |     00:17:40 |  100.00% |         |  0.0019 |        |  0.0070 | |  15 |  800 |     00:19:18 |  100.00% |  99.90% |  0.0015 | 0.0050 |  0.0070 | |  16 |  850 |     00:20:45 |  100.00% |         |  0.0022 |        |  0.0070 | |  17 |  900 |     00:22:17 |  100.00% |  99.93% |  0.0043 | 0.0072 |  0.0070 | |  18 |  950 |     00:23:48 |   99.22% |         |  0.0134 |        |  0.0070 | |  19 | 1000 |     00:25:17 |  100.00% |  99.80% |  0.0051 | 0.0097 |  0.0070 | |  20 | 1050 |     00:26:37 |   98.44% |         |  0.0413 |        |  0.0070 | |  21 | 1100 |     00:28:02 |  100.00% |  99.30% |  0.0116 | 0.0224 |  0.0049 | |  22 | 1150 |     00:29:21 |  100.00% |         |  0.0073 |        |  0.0049 | |  23 | 1200 |     00:30:36 |  100.00% |  99.87% |  0.0055 | 0.0048 |  0.0049 | |  24 | 1250 |     00:31:05 |  100.00% |         |  0.0023 |        |  0.0049 | |  25 | 1300 |     00:31:45 |  100.00% |  99.93% |  0.0050 | 0.0028 |  0.0049 | |  25 | 1350 |     00:32:18 |  100.00% |  99.97% |  0.0051 | 0.0023 |  0.0049 | |=============================================================================| 训练结束: 已完成最大轮数。 评估模型性能... 测试准确率: 99.9000% 总推理时间: 14.32秒 | 单样本: 4.7727毫秒 不支持将脚本 confusionchart 作为函数执行: C:\Program Files\MATLAB\R2024a\toolbox\shared\mlearnlib\confusionchart.m 出错 untitled (第 168 行) confusionchart(YTest, YPred); >>
06-27
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Lord12Snow3

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值