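# Preface
Training a deep learning network in MATLAB breaks down into four core steps: data preparation, network construction, training configuration, and model evaluation. The sections below walk through each step, followed by some advanced techniques and a case study.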
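# 1. Data Preparation
- Load and organize the data
```matlab
% Image data (the MNIST-style digit dataset that ships with MATLAB)
digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', ...
    'nndemos', 'nndatasets', 'DigitDataset');
imds = imageDatastore(digitDatasetPath, ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');   % folder names become the class labels
% Text data (IMDB movie reviews as an example; needs Text Analytics Toolbox)
tbl = readtable('imdb_reviews.csv');
documents = tokenizedDocument(tbl.Review);
```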
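- Data augmentation and splitting
```matlab
% Split into training/validation/test sets (70% / 15% / remaining 15% per class)
[imdsTrain, imdsVal, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 'randomized');
% Image augmentation, applied to the training set only
augmenter = imageDataAugmenter(...
    'RandRotation', [-15, 15], ...
    'RandXReflection', true);
% Resize to the network input size; gray2rgb expands grayscale images to 3 channels
augimds = augmentedImageDatastore([224 224 3], imdsTrain, ...
    'DataAugmentation', augmenter, ...
    'ColorPreprocessing', 'gray2rgb');
```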
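To make the 70/15/15 split concrete, here is a minimal sketch in plain MATLAB of a per-class (stratified) split on a hypothetical label vector. It only illustrates the idea and is not the toolbox implementation; the class names, class sizes, and random seed are assumptions chosen for illustration.

```matlab
% Hypothetical labels: 3 classes with 20 samples each
labels = categorical(repelem([1; 2; 3], 20), 1:3, {'a', 'b', 'c'});
rng(0);                                   % fixed seed so the split is repeatable

trainIdx = []; valIdx = []; testIdx = [];
cats = categories(labels);
for k = 1:numel(cats)
    idx = find(labels == cats{k});        % indices of the samples in class k
    idx = idx(randperm(numel(idx)));      % shuffle within the class
    nTrain = round(0.70 * numel(idx));    % 70% for training
    nVal   = round(0.15 * numel(idx));    % 15% for validation
    trainIdx = [trainIdx; idx(1:nTrain)];
    valIdx   = [valIdx;   idx(nTrain+1:nTrain+nVal)];
    testIdx  = [testIdx;  idx(nTrain+nVal+1:end)];   % remainder for testing
end

fprintf('train/val/test sizes: %d / %d / %d\n', ...
    numel(trainIdx), numel(valIdx), numel(testIdx));
```

Splitting within each class keeps the class proportions the same in all three subsets, which matters when some classes are much smaller than others.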
# 2. Network Construction
- Custom network (a CNN as an example)
```matlab
numClasses = numel(categories(imds.Labels));   % number of output classes
layers = [
    imageInputLayer([224 224 3])
    convolution2dLayer(3, 16, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    convolution2dLayer(3, 32, 'Padding', 'same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2, 'Stride', 2)
    fullyConnectedLayer(128)
    reluLayer
    dropoutLayer(0.5)
    fullyConnectedLayer(numClasses)
    softmaxLayer
    classificationLayer
    ];
```
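As a quick sanity check on the layer list above, the following sketch works out the feature map sizes and the size of the first fully connected layer by hand. It assumes the default convolution stride of 1, so that `'same'` padding preserves the spatial size; the variable names are illustrative only.

```matlab
inputSize = 224;                  % spatial size of the input image
poolSize = 2; poolStride = 2;     % max pooling settings from the layer list

% With 'same' padding and stride 1, the convolutions keep the spatial size,
% so only the two pooling layers shrink the feature map
afterPool1 = floor((inputSize  - poolSize) / poolStride) + 1;   % 112
afterPool2 = floor((afterPool1 - poolSize) / poolStride) + 1;   % 56

numChannels = 32;                           % filters in the second convolution
fcInputs = afterPool2^2 * numChannels;      % features entering fullyConnectedLayer(128)
fcParams = fcInputs * 128 + 128;            % weights plus biases

fprintf('Feature map: %d -> %d -> %d\n', inputSize, afterPool1, afterPool2);
fprintf('First fully connected layer: %d inputs, %d learnable parameters\n', ...
    fcInputs, fcParams);
```

Under these assumptions the first fully connected layer holds about 12.8 million parameters, far more than the two convolutions combined, which is one reason additional pooling or a global pooling layer is often placed before it.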
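- Use a pretrained model (transfer learning)
```matlab
% Load pretrained ResNet-50 (needs the ResNet-50 support package)
net = resnet50;
% Modify the network: replace the 1000-class output layers
lgraph = layerGraph(net);
% ResNet-50 layer names; verify them with analyzeNetwork(net) on your release
lgraph = removeLayers(lgraph, {'fc1000', 'fc1000_softmax', 'ClassificationLayer_fc1000'});
lgraph = addLayers(lgraph, [
    fullyConnectedLayer(numClasses, 'Name', 'fc')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'classoutput')
    ]);
% Connect the new layers to the last pooling layer
lgraph = connectLayers(lgraph, 'avg_pool', 'fc');
```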
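# 3. Training Configuration and Execution
- Set the training options
```matlab
% Resize the validation images so they match the network input size
augimdsVal = augmentedImageDatastore([224 224 3], imdsVal, ...
    'ColorPreprocessing', 'gray2rgb');
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.001, ...
    'MaxEpochs', 10, ...
    'MiniBatchSize', 64, ...
    'Shuffle', 'every-epoch', ...
    'ValidationData', augimdsVal, ...
    'ValidationFrequency', 30, ...        % validate every 30 iterations
    'Verbose', false, ...
    'Plots', 'training-progress', ...
    'LearnRateSchedule', 'piecewise', ...
    'LearnRateDropFactor', 0.1, ...
    'LearnRateDropPeriod', 5);
```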
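The piecewise schedule configured above can be worked out by hand. The sketch below assumes the learning rate is multiplied by `LearnRateDropFactor` each time `LearnRateDropPeriod` epochs have elapsed, and simply tabulates the resulting rate per epoch; it is an illustration of the arithmetic, not the toolbox's internal code.

```matlab
initialLR  = 0.001;   % 'InitialLearnRate'
dropFactor = 0.1;     % 'LearnRateDropFactor'
dropPeriod = 5;       % 'LearnRateDropPeriod'
maxEpochs  = 10;      % 'MaxEpochs'

epochs = 1:maxEpochs;
% Number of completed drop periods before each epoch starts
numDrops = floor((epochs - 1) / dropPeriod);
lr = initialLR * dropFactor .^ numDrops;

for e = epochs
    fprintf('epoch %2d: learning rate %.4g\n', e, lr(e));
end
```

With these settings, epochs 1 to 5 run at 0.001 and epochs 6 to 10 at 0.0001, so the schedule drops exactly once during the 10-epoch run.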
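- Run training
```matlab
% Train the custom network (the augmented datastore matches the input layer size)
net = trainNetwork(augimds, layers, options);
% Transfer learning: trainingOptions has no freezing option, so zero the learn-rate factors of the first 10 layers
names = {lgraph.Layers(1:10).Name};
for i = 1:numel(names)
    layer = lgraph.Layers(strcmp({lgraph.Layers.Name}, names{i}));
    if isprop(layer, 'WeightLearnRateFactor')      % conv/FC layers only
        layer.WeightLearnRateFactor = 0;
        layer.BiasLearnRateFactor = 0;
        lgraph = replaceLayer(lgraph, names{i}, layer);
    end
end
options = trainingOptions('sgdm', 'InitialLearnRate', 0.0001);
net = trainNetwork(augimds, lgraph, options);
```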
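# 4. Model Evaluation and Optimization
- Performance evaluation
```matlab
% Evaluate on the test set (resized to the network input size)
augimdsTest = augmentedImageDatastore([224 224 3], imdsTest, ...
    'ColorPreprocessing', 'gray2rgb');
YPred = classify(net, augimdsTest);
YTest = imdsTest.Labels;
accuracy = mean(YPred == YTest);
fprintf('Test accuracy: %.2f%%\n', accuracy*100);
% Confusion matrix (rows: true class, columns: predicted class)
cm = confusionmat(YTest, YPred);
figure
confusionchart(cm, categories(YTest));
```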
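Accuracy alone can hide weak classes, so it is often worth deriving per-class precision and recall from the confusion matrix. The following is a minimal sketch on a hypothetical 2-class matrix (the counts are made up for illustration), using the same row/column convention as `confusionmat`.

```matlab
% Hypothetical confusion matrix (rows: true class, columns: predicted class)
cm = [50  2;
       5 43];

accuracy  = sum(diag(cm)) / sum(cm(:));     % fraction of all samples classified correctly
recall    = diag(cm) ./ sum(cm, 2);         % per class: correct / samples truly in the class
precision = diag(cm) ./ sum(cm, 1)';        % per class: correct / samples predicted as the class
f1 = 2 * (precision .* recall) ./ (precision + recall);

fprintf('Accuracy: %.2f%%\n', accuracy * 100);
disp(table(precision, recall, f1));
```

For the real model, replace the hard-coded `cm` with the matrix returned by `confusionmat(YTest, YPred)`.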
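- Model optimization
```matlab
% Hyperparameter search with Bayesian optimization
% (bayesopt needs Statistics and Machine Learning Toolbox)
hyperparams = [
    optimizableVariable('LearnRate', [1e-4, 1e-2], 'Transform', 'log')
    optimizableVariable('BatchSize', {'32', '64', '128'}, 'Type', 'categorical')];
% myTrainingFcn is user-supplied: it takes a one-row table of hyperparameter
% values, trains the network, and returns the validation error to minimize
results = bayesopt(@myTrainingFcn, hyperparams, ...
    'MaxObjectiveEvaluations', 20);
```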
# 5. Advanced Techniques
- GPU acceleration
```matlab
% Needs Parallel Computing Toolbox and a supported GPU
options = trainingOptions('sgdm', ...
    'ExecutionEnvironment', 'gpu');
```
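- Model interpretation
```matlab
% Grad-CAM, a gradient-based form of class activation mapping (gradCAM needs R2021a or later)
I = imresize(imread('test_image.jpg'), [224 224]);   % match the network input size
[YPred, scores] = classify(net, I);                  % scores holds the class probabilities
cam = gradCAM(net, I, YPred);
imshow(I); hold on
imagesc(cam, 'AlphaData', 0.5); colormap jet; hold off   % overlay the heat map
title(string(YPred));
```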
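- Model deployment
```matlab
% Export to ONNX format (needs the ONNX converter support package)
exportONNXNetwork(net, 'model.onnx');
% Generate C++ code; myPredict is a hypothetical entry-point function that loads
% the network with coder.loadDeepLearningNetwork and calls predict
cfg = coder.config('lib');
cfg.TargetLang = 'C++';
cfg.DeepLearningConfig = coder.DeepLearningConfig('mkldnn');
codegen -config cfg myPredict -args {zeros(224, 224, 3, 'single')}
```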
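# 6. Case Study: COVID-19 Lung CT Image Classification
```matlab
% Load the data (one subfolder per class)
imds = imageDatastore('covid_ct_scans', ...
    'IncludeSubfolders', true, ...
    'LabelSource', 'foldernames');
% Split first so that validation and test images are never used for training
% (same 70/15/15 split as in the data preparation section)
[imdsTrain, imdsVal, imdsTest] = splitEachLabel(imds, 0.7, 0.15, 'randomized');
% Data augmentation (training set only)
augmenter = imageDataAugmenter(...
    'RandRotation', [-10, 10], ...
    'RandXScale', [0.9, 1.1], ...
    'RandYScale', [0.9, 1.1]);
% gray2rgb expands grayscale scans to the 3 channels ResNet-18 expects
augimdsTrain = augmentedImageDatastore([224 224 3], imdsTrain, ...
    'DataAugmentation', augmenter, 'ColorPreprocessing', 'gray2rgb');
augimdsVal = augmentedImageDatastore([224 224 3], imdsVal, ...
    'ColorPreprocessing', 'gray2rgb');
augimdsTest = augmentedImageDatastore([224 224 3], imdsTest, ...
    'ColorPreprocessing', 'gray2rgb');
% Load pretrained ResNet-18
net = resnet18;
lgraph = layerGraph(net);
% Replace the 1000-class output layers with a 2-class head
% (ResNet-18 layer names; verify them with analyzeNetwork(net) on your release)
lgraph = removeLayers(lgraph, {'fc1000', 'prob', 'ClassificationLayer_predictions'});
lgraph = addLayers(lgraph, [
    fullyConnectedLayer(2, 'Name', 'fc')
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'classoutput')
    ]);
% 'pool5' is already a global average pooling layer, so no extra pooling is needed
lgraph = connectLayers(lgraph, 'pool5', 'fc');
% Training options
options = trainingOptions('sgdm', ...
    'InitialLearnRate', 0.0001, ...
    'MaxEpochs', 15, ...
    'MiniBatchSize', 16, ...
    'ValidationData', augimdsVal, ...
    'ValidationFrequency', 30, ...
    'Verbose', false, ...
    'Plots', 'training-progress');
% Train the network
net = trainNetwork(augimdsTrain, lgraph, options);
% Evaluate on the held-out test set
YPred = classify(net, augimdsTest);
YTest = imdsTest.Labels;   % labels live on the underlying imageDatastore
accuracy = mean(YPred == YTest);
fprintf('CT image classification accuracy: %.2f%%\n', accuracy*100);
```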