I am writing the final course paper for my Computational Intelligence class. Please draw on deep learning, machine learning, and computational intelligence, along with related knowledge, to help me complete it. Note that the experiments are run in MATLAB R2024a.
Task 5: Train the MNIST dataset with a ResNet network (20 points)
[Briefly describe the principle of the ResNet network]
[Explain the ResNet network structure and key parameter settings]
[Presentation of experimental results]
[Analysis of experimental results and directions for improvement]
[Code presentation]
I am currently working on the [Code presentation] section. Please fix the error reported in the run output of the code below and give me the complete, corrected code.
Code:
%% Task 5: Final runnable ResNet-MNIST recognition system
% Fixes the label-format issue to ensure full compatibility
clear; clc; close all;
rng(2024, 'twister'); % fixed random seed for reproducibility
%% Compatible data loading scheme (MATLAB's built-in DigitDataset as an MNIST substitute)
fprintf('Using the compatible data loading scheme...\n');
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...
'nndatasets', 'DigitDataset');
% Load the training set
trainImds = imageDatastore(digitDatasetPath, ...
'IncludeSubfolders', true, 'LabelSource', 'foldernames');
[trainImds, testImds] = splitEachLabel(trainImds, 0.7, 'randomized');
% Convert to a 4-D array
XTrain = readall(trainImds);
if iscell(XTrain)
XTrain = cat(4, XTrain{:});
end
YTrain = trainImds.Labels; % use the categorical labels directly
% Load the test set
XTest = readall(testImds);
if iscell(XTest)
XTest = cat(4, XTest{:});
end
YTest = testImds.Labels; % use the categorical labels directly
% Ensure grayscale images (single channel)
% Note: rgb2gray does not accept 4-D stacks, so use a weighted channel sum instead
if size(XTrain, 3) == 3
XTrain = uint8(0.2989*double(XTrain(:,:,1,:)) + 0.5870*double(XTrain(:,:,2,:)) + 0.1140*double(XTrain(:,:,3,:)));
XTest = uint8(0.2989*double(XTest(:,:,1,:)) + 0.5870*double(XTest(:,:,2,:)) + 0.1140*double(XTest(:,:,3,:)));
end
if size(XTrain, 3) == 1
XTrain = reshape(XTrain, [size(XTrain,1), size(XTrain,2), 1, size(XTrain,4)]);
XTest = reshape(XTest, [size(XTest,1), size(XTest,2), 1, size(XTest,4)]);
end
% Resize all images to 28x28
if size(XTrain,1) ~= 28 || size(XTrain,2) ~= 28
XTrain = imresize(XTrain, [28, 28]);
XTest = imresize(XTest, [28, 28]);
end
fprintf('Dataset loaded: %d training samples, %d test samples\n', ...
size(XTrain,4), size(XTest,4));
%% Data augmentation (compatible scheme)
augmenter = imageDataAugmenter(...
'RandRotation', [-15 15], ...
'RandXTranslation', [-3 3], ...
'RandYTranslation', [-3 3]);
imdsTrain = augmentedImageDatastore([28 28 1], XTrain, YTrain, ...
'DataAugmentation', augmenter);
%% Residual-style network built as a purely sequential stack (fixed version; see the layerGraph sketch after this listing for a true residual block)
layers = [
% === Input layer ===
imageInputLayer([28 28 1], 'Name', 'input', 'Normalization', 'none')
% === Initial convolution ===
convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'conv1')
batchNormalizationLayer('Name', 'bn1')
reluLayer('Name', 'relu1')
% === Residual block 1 ===
% Main path
convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'res1_conv1')
batchNormalizationLayer('Name', 'res1_bn1')
reluLayer('Name', 'res1_relu1')
convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'res1_conv2')
batchNormalizationLayer('Name', 'res1_bn2')
% Residual connection approximated with a fixed 1x1 convolution (not a true element-wise addition)
convolution2dLayer(1, 16, 'Name', 'res1_add', ...
'WeightsInitializer', @(sz) 2 * reshape(eye(16), [1,1,16,16]), ... % fixed weight initialization (2x identity)
'BiasInitializer', 'zeros', ...
'WeightLearnRateFactor', 0, 'BiasLearnRateFactor', 0) % frozen weights
batchNormalizationLayer('Name', 'res1_add_bn')
reluLayer('Name', 'res1_final_relu')
% === Residual block 2 (with downsampling) ===
% Main path
convolution2dLayer(3, 32, 'Padding', 'same', 'Stride', 2, 'Name', 'res2_conv1')
batchNormalizationLayer('Name', 'res2_bn1')
reluLayer('Name', 'res2_relu1')
convolution2dLayer(3, 32, 'Padding', 'same', 'Name', 'res2_conv2')
batchNormalizationLayer('Name', 'res2_bn2')
% Shortcut path (with downsampling)
convolution2dLayer(1, 32, 'Stride', 2, 'Name', 'res2_shortcut')
batchNormalizationLayer('Name', 'res2_bn_shortcut')
% 'Addition' emulated with a fixed 1x1 convolution
convolution2dLayer(1, 32, 'Name', 'res2_add', ...
'WeightsInitializer', @(sz) 2 * reshape(eye(32), [1,1,32,32]), ... % fixed weight initialization (2x identity)
'BiasInitializer', 'zeros', ...
'WeightLearnRateFactor', 0, 'BiasLearnRateFactor', 0) % frozen weights
batchNormalizationLayer('Name', 'res2_add_bn')
reluLayer('Name', 'res2_final_relu')
% === Residual block 3 ===
% Main path
convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'res3_conv1')
batchNormalizationLayer('Name', 'res3_bn1')
reluLayer('Name', 'res3_relu1')
convolution2dLayer(3, 64, 'Padding', 'same', 'Name', 'res3_conv2')
batchNormalizationLayer('Name', 'res3_bn2')
% Shortcut path
convolution2dLayer(1, 64, 'Name', 'res3_shortcut', ...
'WeightLearnRateFactor', 0, 'BiasLearnRateFactor', 0) % frozen weights
batchNormalizationLayer('Name', 'res3_bn_shortcut')
% 'Addition' emulated with a fixed 1x1 convolution
convolution2dLayer(1, 64, 'Name', 'res3_add', ...
'WeightsInitializer', @(sz) 2 * reshape(eye(64), [1,1,64,64]), ... % fixed weight initialization (2x identity)
'BiasInitializer', 'zeros', ...
'WeightLearnRateFactor', 0, 'BiasLearnRateFactor', 0) % frozen weights
batchNormalizationLayer('Name', 'res3_add_bn')
reluLayer('Name', 'res3_final_relu')
% === Classification head ===
globalAveragePooling2dLayer('Name', 'gap')
fullyConnectedLayer(10, 'Name', 'fc')
softmaxLayer('Name', 'softmax')
classificationLayer('Name', 'output')
];
%% Training configuration
options = trainingOptions('adam', ...
'InitialLearnRate', 0.01, ...
'LearnRateSchedule', 'piecewise', ...
'LearnRateDropPeriod', 10, ...
'LearnRateDropFactor', 0.7, ...
'MaxEpochs', 25, ...
'MiniBatchSize', 128, ...
'Shuffle', 'every-epoch', ...
'ValidationData', {XTest, YTest}, ... % categorical validation labels
'ValidationFrequency', 100, ...
'Verbose', true, ...
'Plots', 'training-progress', ...
'ExecutionEnvironment', 'cpu');
%% Model training
fprintf('Starting network training...\n');
net = trainNetwork(imdsTrain, layers, options);
%% Model evaluation
fprintf('Evaluating model performance...\n');
tic;
[YPred, probs] = classify(net, XTest, 'ExecutionEnvironment', 'cpu');
inferenceTime = toc;
accuracy = mean(YPred == YTest);
fprintf('Test accuracy: %.4f%%\n', accuracy*100);
fprintf('Total inference time: %.2f s | per sample: %.4f ms\n', ...
inferenceTime, inferenceTime*1000/size(XTest,4));
%% Result visualization
% Confusion matrix
figure;
confusionchart(YTest, YPred);
title(sprintf('ResNet-MNIST (accuracy: %.4f%%)', accuracy*100));
% Sample prediction display
figure;
numSamples = 9;
randIndices = randperm(size(XTest,4), numSamples);
for i = 1:numSamples
subplot(3,3,i);
img = XTest(:,:,:,randIndices(i));
imshow(img, []);
predLabel = char(YPred(randIndices(i)));
trueLabel = char(YTest(randIndices(i)));
if strcmp(predLabel, trueLabel)
color = 'g';
else
color = 'r';
end
title(sprintf('True: %s | Predicted: %s', trueLabel, predLabel), 'Color', color);
end
%% Save the model
save('ResNet_MNIST_Final.mat', 'net', 'accuracy', 'inferenceTime');
fprintf('Model saved as ResNet_MNIST_Final.mat\n');
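Note: in the layer array above, the residual connections are only approximated by fixed 1x1 convolutions inside a purely sequential stack; no element-wise addition of a shortcut actually takes place. A true residual block in MATLAB would use layerGraph, additionLayer, and connectLayers. The following is a minimal, illustrative sketch only (layer names are hypothetical and it is not part of the script above):

% Sketch: a genuine residual block with an identity shortcut.
lgraph = layerGraph([
imageInputLayer([28 28 1], 'Name', 'input', 'Normalization', 'none')
convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'conv1')
batchNormalizationLayer('Name', 'bn1')
reluLayer('Name', 'relu1')
convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'res1_conv1')
batchNormalizationLayer('Name', 'res1_bn1')
reluLayer('Name', 'res1_relu1')
convolution2dLayer(3, 16, 'Padding', 'same', 'Name', 'res1_conv2')
batchNormalizationLayer('Name', 'res1_bn2')
additionLayer(2, 'Name', 'res1_add') % element-wise addition of main path and shortcut
reluLayer('Name', 'res1_final_relu')
]);
% Identity shortcut: route the block input ('relu1') into the second input of the addition layer.
lgraph = connectLayers(lgraph, 'relu1', 'res1_add/in2');

Built this way, the block learns the residual mapping F(x) and adds it to the shortcut x, which matches the ResNet principle described in the report.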
The run output is as follows:
Runtime error:
Using the compatible data loading scheme...
Dataset loaded: 7000 training samples, 3000 test samples
Starting network training...
|=============================================================================|
| Epoch | Iteration | Time Elapsed | Mini-batch Accuracy | Validation Accuracy | Mini-batch Loss | Validation Loss | Base Learning Rate |
|       |           |  (hh:mm:ss)  |                     |                     |                 |                 |                    |
|=============================================================================|
| 1 | 1 | 00:00:36 | 4.69% | 12.43% | 2.3997 | 2.3882 | 0.0100 |
| 1 | 50 | 00:01:38 | 92.97% | | 0.3559 | | 0.0100 |
| 2 | 100 | 00:02:41 | 97.66% | 97.53% | 0.0924 | 0.1026 | 0.0100 |
| 3 | 150 | 00:03:36 | 96.88% | | 0.0899 | | 0.0100 |
| 4 | 200 | 00:04:42 | 99.22% | 99.30% | 0.0434 | 0.0338 | 0.0100 |
| 5 | 250 | 00:05:40 | 100.00% | | 0.0092 | | 0.0100 |
| 6 | 300 | 00:06:45 | 98.44% | 99.47% | 0.0354 | 0.0271 | 0.0100 |
| 7 | 350 | 00:07:37 | 96.09% | | 0.1189 | | 0.0100 |
| 8 | 400 | 00:08:37 | 99.22% | 99.03% | 0.0411 | 0.0391 | 0.0100 |
| 9 | 450 | 00:09:31 | 99.22% | | 0.0212 | | 0.0100 |
| 10 | 500 | 00:10:29 | 98.44% | 99.00% | 0.0283 | 0.0367 | 0.0100 |
| 11 | 550 | 00:12:01 | 100.00% | | 0.0092 | | 0.0070 |
| 12 | 600 | 00:13:30 | 100.00% | 99.73% | 0.0111 | 0.0078 | 0.0070 |
| 13 | 650 | 00:14:48 | 100.00% | | 0.0087 | | 0.0070 |
| 13 | 700 | 00:16:20 | 100.00% | 99.90% | 0.0098 | 0.0044 | 0.0070 |
| 14 | 750 | 00:17:40 | 100.00% | | 0.0019 | | 0.0070 |
| 15 | 800 | 00:19:18 | 100.00% | 99.90% | 0.0015 | 0.0050 | 0.0070 |
| 16 | 850 | 00:20:45 | 100.00% | | 0.0022 | | 0.0070 |
| 17 | 900 | 00:22:17 | 100.00% | 99.93% | 0.0043 | 0.0072 | 0.0070 |
| 18 | 950 | 00:23:48 | 99.22% | | 0.0134 | | 0.0070 |
| 19 | 1000 | 00:25:17 | 100.00% | 99.80% | 0.0051 | 0.0097 | 0.0070 |
| 20 | 1050 | 00:26:37 | 98.44% | | 0.0413 | | 0.0070 |
| 21 | 1100 | 00:28:02 | 100.00% | 99.30% | 0.0116 | 0.0224 | 0.0049 |
| 22 | 1150 | 00:29:21 | 100.00% | | 0.0073 | | 0.0049 |
| 23 | 1200 | 00:30:36 | 100.00% | 99.87% | 0.0055 | 0.0048 | 0.0049 |
| 24 | 1250 | 00:31:05 | 100.00% | | 0.0023 | | 0.0049 |
| 25 | 1300 | 00:31:45 | 100.00% | 99.93% | 0.0050 | 0.0028 | 0.0049 |
| 25 | 1350 | 00:32:18 | 100.00% | 99.97% | 0.0051 | 0.0023 | 0.0049 |
|=============================================================================|
Training finished: Max epochs completed.
Evaluating model performance...
Test accuracy: 99.9000%
Total inference time: 14.32 s | per sample: 4.7727 ms
Execution of script confusionchart as a function is not supported:
C:\Program Files\MATLAB\R2024a\toolbox\shared\mlearnlib\confusionchart.m
Error in untitled (line 168)
confusionchart(YTest, YPred);
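A workaround I am considering for this last error (a minimal sketch, untested, assuming the shipped confusionchart.m is being shadowed or mis-resolved on the MATLAB path) is to build the confusion matrix with confusionmat and plot it with heatmap instead:

% Fallback confusion-matrix plot, in case confusionchart cannot be called.
% Assumes net, XTest, YTest from the script above are in the workspace.
YPred = classify(net, XTest, 'ExecutionEnvironment', 'cpu');
accuracy = mean(YPred == YTest);
C = confusionmat(YTest, YPred); % C(i,j): true class i predicted as class j
classNames = categories(YTest); % digit labels '0'..'9'
figure;
heatmap(classNames, classNames, C, ...
'XLabel', 'Predicted class', 'YLabel', 'True class', ...
'Title', sprintf('ResNet-MNIST (accuracy: %.4f%%)', accuracy*100));

Alternatively, running which confusionchart -all should show whether a local file or a stale path entry is shadowing the toolbox function; restoredefaultpath followed by rehash toolboxcache may also resolve it.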