哔哩哔哩教程：【卷积神经网络数字图片识别（MATLAB）】https://www.bilibili.com/video/BV1n1421y7kZ
一、搭建并训练卷积神经网络模型（数字图片识别模型）
% Step 1: Prepare the data.
% The DigitDataset stores one subfolder per digit class, so the folder names
% are used as labels.
digitDatasetPath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
digitData = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders',true,'LabelSource','foldernames');

% Split the data set BEFORE augmenting: the augmenting datastore must wrap the
% training split (and be passed to trainNetwork), otherwise augmentation never
% takes effect. The test split stays un-augmented.
% A fixed seed makes the random split (and the reported accuracy) reproducible.
rng(0);
[trainImgs,testImgs] = splitEachLabel(digitData,0.7,'randomized');

% Data augmentation (training split only)
inputSize = [28 28 1];
augmenter = imageDataAugmenter( ...
    'RandRotation',[-10,10], ...     % random rotation between -10 and 10 degrees
    'RandXTranslation',[-3,3], ...   % random horizontal shift between -3 and 3 pixels
    'RandYTranslation',[-3,3]);      % random vertical shift between -3 and 3 pixels
augTrainImgs = augmentedImageDatastore(inputSize(1:2),trainImgs, ...
    'DataAugmentation',augmenter);

% Step 2: Define the network architecture.
% Derive the number of output classes from the labels instead of hard-coding it
% (for the ten digit folders this is 10, the value used originally).
numClasses = numel(categories(trainImgs.Labels));
layers = [
    imageInputLayer(inputSize)
    convolution2dLayer(3,16,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer
    ];

% Step 3: Train the network on the augmented training data.
options = trainingOptions('sgdm', ...
    'MaxEpochs',5, ...
    'MiniBatchSize',64, ...
    'InitialLearnRate',0.01, ...
    'Shuffle','every-epoch', ...
    'Verbose',false, ...
    'Plots','training-progress');
net = trainNetwork(augTrainImgs,layers,options);

% Step 4: Evaluate on the held-out (un-augmented) test split.
predictedLabels = classify(net,testImgs);
accuracy = mean(predictedLabels == testImgs.Labels);
fprintf('测试准确率: %.2f%%\n', accuracy * 100);

% Save the trained network for the inference script and the app.
save('trainedDigitNet.mat', 'net');
二、调用训练好的模型识别单张图片
% Load the trained network. Loading into a struct makes the dependency on the
% variable name 'net' explicit instead of creating it implicitly.
S = load('trainedDigitNet.mat', 'net');
if ~isfield(S, 'net')
    error('trainedDigitNet.mat does not contain a variable named ''net''.');
end
net = S.net;

% Read the digit image to recognise.
imagePath = 'C:\Users\19350\Desktop\920.jpg'; % replace with the path to your image
[inputImage, cmap] = imread(imagePath);

% Indexed (palette) images are returned as colormap indices; convert them to
% true-colour uint8 first, otherwise the raw indices would be classified.
if ~isempty(cmap)
    inputImage = uint8(round(255 * ind2rgb(inputImage, cmap)));
end

% Convert to a single grayscale channel. rgb2gray only accepts 3-channel input,
% so images that are already grayscale are left as they are.
if size(inputImage, 3) == 3
    inputImage = rgb2gray(inputImage);
end

% Resize to the 28x28 input size used when training the network.
inputImage = imresize(inputImage, [28, 28]);

% Classify the image.
predictedLabel = classify(net, inputImage);

% Show the prediction.
disp(['识别结果为: ', char(predictedLabel)]);
三、创建 App：卷积神经网络数字图片识别的可视化界面
classdef DigitRecognizerApp < matlab.apps.AppBase
    % DigitRecognizerApp  Small UI that classifies a digit image with the
    % network stored in trainedDigitNet.mat.
    %
    %   app = DigitRecognizerApp opens a window with a load-image button.
    %   The chosen image is converted to grayscale, resized to 28x28, shown
    %   in the axes and classified. The predicted label is written to the
    %   result label.

    % Properties that correspond to app components
    properties (Access = public)
        UIFigure               matlab.ui.Figure
        LoadImageButton        matlab.ui.control.Button
        UIAxes                 matlab.ui.control.UIAxes
        RecognitionResultLabel matlab.ui.control.Label
    end

    properties (Access = private)
        net % Trained network loaded from trainedDigitNet.mat (empty if loading failed)
    end

    methods (Access = private)

        % Button pushed function: LoadImageButton
        function LoadImageButtonPushed(app, event)
            % Ask for an image file, preprocess it, display it and show the
            % predicted digit. Errors are reported in the UI, not thrown.
            if isempty(app.net)
                uialert(app.UIFigure, '未加载识别模型，请先运行训练脚本生成 trainedDigitNet.mat。', '模型未加载');
                return;
            end

            % "folder" rather than "path" to avoid shadowing MATLAB's path function.
            [file, folder] = uigetfile({'*.jpg;*.png;*.bmp', 'Image Files'});
            if isequal(file, 0)
                return; % user cancelled the dialog
            end

            try
                inputImage = DigitRecognizerApp.preprocessImage(fullfile(folder, file));

                % Display the image
                imshow(inputImage, 'Parent', app.UIAxes);

                % Use the trained network to classify the image
                predictedLabel = classify(app.net, inputImage);
            catch ME
                uialert(app.UIFigure, ME.message, '识别失败');
                return;
            end

            % Display the result
            app.RecognitionResultLabel.Text = ['识别结果为: ', char(predictedLabel)];
        end
    end

    methods (Static, Access = private)

        function img = preprocessImage(imagePath)
            % Read an image file and convert it to the 28x28 single-channel
            % layout that the network was trained on.
            [img, cmap] = imread(imagePath);

            % Indexed (palette) images come back as colormap indices;
            % expand them to true-colour uint8 first.
            if ~isempty(cmap)
                img = uint8(round(255 * ind2rgb(img, cmap)));
            end

            % rgb2gray only accepts 3-channel input; grayscale images are
            % already single-channel and are left unchanged.
            if size(img, 3) == 3
                img = rgb2gray(img);
            end

            img = imresize(img, [28, 28]);
        end
    end

    methods (Access = public)

        % Construct app
        function app = DigitRecognizerApp
            % Create and configure components
            createComponents(app)

            % Tie the app's lifetime to its figure (App Designer convention).
            registerApp(app, app.UIFigure)

            % Load the pre-trained network. On failure, keep the window open,
            % tell the user, and leave app.net empty so the callback can
            % refuse to classify.
            try
                S = load('trainedDigitNet.mat', 'net');
                app.net = S.net;
            catch ME
                uialert(app.UIFigure, ME.message, '模型加载失败');
            end

            % Do not leave an "ans" handle behind when called without output.
            if nargout == 0
                clear app
            end
        end

        % Code that executes before app deletion
        function delete(app)
            % Delete UIFigure when app is deleted
            delete(app.UIFigure)
        end
    end

    methods (Access = private)

        % Create and configure components
        function createComponents(app)
            % Create UIFigure hidden, then show it once fully built.
            app.UIFigure = uifigure('Visible', 'off');
            app.UIFigure.Position = [100 100 400 300];
            app.UIFigure.Name = '数字识别应用';

            % Create LoadImageButton
            app.LoadImageButton = uibutton(app.UIFigure, 'push');
            app.LoadImageButton.ButtonPushedFcn = createCallbackFcn(app, @LoadImageButtonPushed, true);
            app.LoadImageButton.Position = [150 220 100 30];
            app.LoadImageButton.Text = '加载图像';

            % Create UIAxes
            app.UIAxes = uiaxes(app.UIFigure);
            app.UIAxes.Position = [50 50 150 150];

            % Create RecognitionResultLabel
            app.RecognitionResultLabel = uilabel(app.UIFigure);
            app.RecognitionResultLabel.Position = [220 50 150 30];
            app.RecognitionResultLabel.Text = '识别结果为: ';

            % Show the figure after all components are created
            app.UIFigure.Visible = 'on';
        end
    end
end
四、最终结果展示