01迁移学习的简单介绍
迁移学习是一种机器学习方法,将源领域知识迁移到目标领域,能使目标领域取得更好的学习效果。简单来讲就是将训练好的典型网络模型按照自己的需求进行改进,使用相对较少的数据来进行模型的训练,并且这种方法对于硬件的要求也不是太高。基于共享参数的迁移学习是迁移学习的一种,研究源数据和目标数据空间模型之间的共同参数或者先验分布的探索,这种迁移学习的前提是学习任务中每个相关模型会共享一些相同的参数,或者先验分布。
图1 基于共享参数的迁移示意图
02 AlexNet网络原理
2012年的Image Net挑战赛上,Alex Krizhevsky首次使用了这种深度卷积网络,卷积神经网络包含多个卷积层能够自主提取输入的有效信息。AlexNet含有一个输入层,一个输出层,5个卷积层,3个池化层,2个全连接层。使用卷积和池化组合提取图片特征,使用ReLu充当激活函数,使用Dropout和数据扩充手段来抑制过拟合。
图2 AlexNet网络结构及参数示意图
03 基于AlexNet卷积神经网络实现迁移学习
基于AlexNet实现迁移学习的主要步骤有:1.加载图像,划分训练集和验证集→2.加载已经训练好的AlexNet网络→3.按需改变网络结构→4.调整数据集→5.训练网络→6观察训练结果。 MATLAB的深度学习工具箱内置了预训练好的深度神经网络模型,它的最后三层是用来对1000个类别的物体进行识别,若产生新的分类问题则需要对这三层进行调整。 取出除最后三层以外的所有层,将最后三层分别用全连接层,输出层,分类层替代。将原有预训练的网络层迁移到新的分类任务上。根据新数据设定新的全连接层参数,将全连接层的分类数设置为与新数据中的分类数相同。
% 加载图像数据
unzip('MerchData.zip');
imds = imageDatastore('MerchData', ...
'IncludeSubfolders',true, ...
'LabelSource','foldernames');
% 划分验证集和训练集8:2
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.8,'randomized');
% 随机显示训练集中的部分图像
numTrainImages = numel(imdsTrain.Labels);
idx = randperm(numTrainImages,16);
figure
for i = 1:16
subplot(4,4,i)
I = readimage(imdsTrain,idx(i));
imshow(I)
end
%% 步骤2:加载预训练好的网络
% 加载alexnet网络(注:该网络需要提前下载,当输入下面命令时按要求下载即可)
net = alexnet;
%% 步骤3:对网络结构进行改进
%保留AlexNet倒数第三层之前的网络
layersTransfer = net.Layers(1:end-3);
% 确定训练数据中需要分类的种类
numClasses = numel(categories(imdsTrain.Labels));
% 构建新的网络,保留AlexNet倒数第三层之前的网络,在此之后重新添加了全连接
layers = [
layersTransfer % 保留AlexNet倒数第三层之前的网络
fullyConnectedLayer(numClasses) % 将新的全连接层的输出设置为训练数据中的种类
softmaxLayer % 添加新的Softmax层
classificationLayer ]; % 添加新的分类层
%% 调整数据集
% 查看网络输入层的大小和通道数
inputSize = net.Layers(1).InputSize;
% 将批量训练图像的大小调整为与输入层的大小相同
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain);
% 将批量验证图像的大小调整为与输入层的大小相同
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
%% 对网络进行训练
% 对训练参数进行设置
options = trainingOptions('sgdm', ...
'MiniBatchSize',15, ...
'MaxEpochs',10, ...
'InitialLearnRate',0.00005, ...
'Shuffle','every-epoch', ...
'ValidationData',augimdsValidation, ...
'ValidationFrequency',3, ...
'Verbose',true, ...
'Plots','training-progress');
% 用训练图像对网络进行训练
netTransfer = trainNetwork(augimdsTrain,layers,options);
%% 验证并显示结果
% 对训练好的网络采用验证数据集进行验证
[YPred,scores] = classify(netTransfer,augimdsValidation);
% 随机显示验证效果
idx = randperm(numel(imdsValidation.Files),4);
figure
for i = 1:4
subplot(2,2,i)
I = readimage(imdsValidation,idx(i));
imshow(I)
label = YPred(idx(i));
title(string(label));
end
%% 计算分类准确率
YValidation = imdsValidation.Labels;
accuracy = mean(YPred == YValidation)
%% 创建并显示混淆矩阵
figure
confusionchart(YValidation,YPred)
04 结果
图3 随机显示测试分类后的图片
图4 训练进度
图5混淆矩阵-分类模型预测结果的情形分析表