MATLAB对inceptionV3模型进行迁移学习

inceptionV3模型的相关论文
使用MATLAB自带的inceptionv3模型进行迁移学习,若没有安装inceptionv3模型支持工具,在命令窗口输入inceptionv3,点击下载链接进行安装。
训练环境:Windows10系统,MATLAB20018b,CPU i3 3.7GHz,4GB内存。
使用inceptionv3模型进行迁移学习的MATLAB代码如下:

%% 加载数据
clc;close all;clear;
unzip('MerchData.zip');
Location = 'E:\Graduation_Project\RBC&WBC';
imds = imageDatastore(Location ,...
                       'IncludeSubfolders',true,...
                       'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7,'randomized');
%% 加载预训练网络
net = inceptionv3;
%% 从训练有素的网络中提取图层,并绘制图层图
lgraph = layerGraph(net);
% figure('Units','normalize','Position',[0.1 0.1 0.8 0.8]);
% plot(lgraph)
% net.Layers(1)
inputSize = net.Layers(1).InputSize;
lgraph = removeLayers(lgraph,{'predictions', 'predictions_softmax', 'ClassificationLayer_predictions'});
%% 替换最终图层
numClasses = numel(categories(imdsTrain.Labels));
newLayers = [
              fullyConnectedLayer(numClasses,'Name','fc','weightLearnRateFactor',10,'BiasLearnRateFactor',10)
              softmaxLayer('Name','softmax')
              classificationLayer('Name','classoutput')];
lgraph = addLayers(lgraph,newLayers);

%%
 lgraph = connectLayers(lgraph,'avg_pool','fc');
% figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
% plot(lgraph)
% ylim([0,10])
 
 %% 冻结初始图层
 layers = lgraph.Layers;
 connections = lgraph.Connections;  
 layers(1:110) = freezeWeights(layers(1:110));
 lgraph = createLgraphUsingConnections(layers,connections);


%% 训练网络
pixelRange = [-30 30];
imageAugmenter = imageDataAugmenter(...
                                    'RandXReflection',true,...
                                    'RandXTranslation',pixelRange,...
                                    'RandYTranslation',pixelRange);
%对输入数据进行数据加强
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);
 %  自动调整验证图像大小
augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);
 %设置训练参数
options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',250, ...
    'ValidationPatience',Inf, ...
    'Verbose',true ,...
    'Plots','training-progress');
 
inceptionv3Train = trainNetwork(augimdsTrain,lgraph,options);
%% 对验证图像进行分类
[YPred,probs] = classify(inceptionv3Train,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)

%% 保存训练好的模型
 save Inceptionv3_02 inceptionv3Train;
% save  x  y;  保存训练好的模型y(注意:y为训练的模型,即y = trainNetwork()),取名为x

freezeWeights函数

function layers = freezeWeights(layers)

for ii = 1:size(layers,1)
    props = properties(layers(ii));
    for p = 1:numel(props)
        propName = props{p};
        if ~isempty(regexp(propName,'LearnRateFactor$','once'))
            layers(ii).(propName) = 0;
        end
    end
end
end

createLgraphUsingConnections函数

function lgraph = createLgraphUsingConnections(layers,connections)

lgraph = layerGraph();

for i = 1:numel(layers)
    lgraph = addLayers(lgraph,layers(i));
end
    

for c = 1:size(connections,1)
    lgraph = connectLayers(lgraph,connections.Source{c},connections.Destination{c});
end
end

findLayersToReplace函数

% findLayersToReplace(lgraph) finds the single classification layer and the
% preceding learnable (fully connected or convolutional) layer of the layer
% graph lgraph.
function [learnableLayer,classLayer] = findLayersToReplace(lgraph)

if ~isa(lgraph,'nnet.cnn.LayerGraph')
    error('Argument must be a LayerGraph object.')
end

% Get source, destination, and layer names.
src = string(lgraph.Connections.Source);
dst = string(lgraph.Connections.Destination);
layerNames = string({lgraph.Layers.Name}');

% Find the classification layer. The layer graph must have a single
% classification layer.
isClassificationLayer = arrayfun(@(l) ...
    (isa(l,'nnet.cnn.layer.ClassificationOutputLayer')|isa(l,'nnet.layer.ClassificationLayer')), ...
    lgraph.Layers);

if sum(isClassificationLayer) ~= 1
    error('Layer graph must have a single classification layer.')
end
classLayer = lgraph.Layers(isClassificationLayer);


% Traverse the layer graph in reverse starting from the classification
% layer. If the network branches, throw an error.
currentLayerIdx = find(isClassificationLayer);
while true
    
    if numel(currentLayerIdx) ~= 1
        error('Layer graph must have a single learnable layer preceding the classification layer.')
    end
    

训练过程
在这里插入图片描述
使用训练好的模型进行图像分类
我这里训练的模型是对细胞显微图像进行分类,包括BYST,GRAN,HYAL,MUCS,RBC,WBC,WBCC七种细胞。
当softmax层输出该图片的概率小于0.5时分类为OTHERS

%% 加载模型
% clc;close all;clear;
load('-mat','E:\MATLAB_Code\Inceptionv3_01');
%% 加载测试集
Location = 'E:\image_test\test_02';
imds = imageDatastore(Location,'includeSubfolders',true,'LabelSource','foldernames');
inputSize = inceptionv3Train.Layers(1).InputSize; 
imdstest = augmentedImageDatastore(inputSize(1:2),imds);
tic;
[YPred,scores] = classify(inceptionv3Train,imdstest);
%使用训练好的模型对测试集进行分类
disp(['分类所用时间为:',num2str(toc),'秒']);
sum = numel(YPred);
% 将softmax输出概率小于0.5的分类为OTHERS
for i = 1:sum
    p = max(scores(i,:));
    if  p < 0.5
        others = 'OTHERS';
        YPred(i) = others;
    end
end
%% 显示分类结果
byst = 'BYST';
BYST = numel(YPred,YPred == byst);
disp(['BYST = ',num2str(BYST)]);
gran = 'GRAN';
GRAN = numel(YPred,YPred == gran);
disp(['GRAN = ',num2str(GRAN)]);
hyal = 'HYAL';
HYAL = numel(YPred,YPred == hyal);
disp(['HYAL = ',num2str(HYAL)]);
mucs = 'MUCS';
MUCS = numel(YPred,YPred == mucs);
disp(['MUCS = ',num2str(MUCS)]);
rbc = 'RBC';
RBC = numel(YPred,YPred == rbc);
disp(['RBC = ',num2str(RBC)]);
wbc = 'WBC';
WBC = numel(YPred,YPred == wbc);
disp(['WBC = ',num2str(WBC)]);
wbcc = 'WBCC';
WBCC = numel(YPred,YPred == wbcc);
disp(['WBCC = ',num2str(WBCC)]);
others = 'OTHERS';
OTHERS = numel(YPred,YPred == others);
disp(['OTHERS = ',num2str(OTHERS)]);
sum = numel(YPred);
disp(['sum = ',num2str(sum)]);
%求出每个标签对应的分类数量
% numel(A)  返回数组A的数目
% numel(A,x) 返回数组A在x的条件下的数目
%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%计算精确度%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 YTest = imds.Labels;
 accuracy = mean(YPred == YTest);
 disp(['accuracy = ',num2str(accuracy)]);
 % disp(x)   变量x的值
 % num2str(x)  将数值数值转换为表示数字的字符数组

%% 绘制混淆矩阵
predictLabel = YPred;%通过训练好的模型分类后的标签
actualLabel = YTest;%原始的标签
plotconfusion(actualLabel,predictLabel,'Inceptionv3');%绘制混淆矩阵
%    plotconfusion(targets,outputs);绘制混淆矩阵,使用target(true)和output(predict)标签,将标签指定为分类向量或1到N的形式
 %%
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%随机显示测试分类后的图片%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
idx = randperm(numel(imds.Files),16);
figure
for i = 1:16
    subplot(4,4,i);
    I = readimage(imds,idx(i));
    imshow(I);
    label = YPred(idx(i));
    title(string(label) + ',' + num2str(100*max(scores(idx(i),:)),3) + '%');

end
% %% 保存分类后的图片
% x = numel(imds.Files);
% % 图片保存位置
% Location_BYST = 'E:\image_classification\Inceptionv3\BYST';
% Location_GRAN = 'E:\image_classification\Inceptionv3\GRAN';
% Location_HYAL = 'E:\image_classification\Inceptionv3\HYAL';
% Location_MUCS = 'E:\image_classification\Inceptionv3\MUCS';
% Location_RBC  = 'E:\image_classification\Inceptionv3\RBC';
% Location_WBC  = 'E:\image_classification\Inceptionv3\WBC';
% Location_WBCC = 'E:\image_classification\Inceptionv3\WBCC';
% Location_OTHERS = 'E:\image_classification\Inceptionv3\OTHERS';
% writePostfix = '.bmp';%图片保存后缀
% for i = 1:x
%     I = readimage(imds,i);
%     Label = YPred(i);
%     Name = YTest(i);
%    switch Label
%        case 'BYST'
%            saveName = sprintf('%s%s%s_%d',Location_BYST,'\',Name,i,writePostfix);
%            imwrite(I,saveName);
%        case 'GRAN'
%            saveName = sprintf('%s%s%s_%d',Location_GRAN,'\',Name,i,writePostfix);
%            imwrite(I,saveName);
%        case 'HYAL'
%            saveName = sprintf('%s%s%s_%d',Location_HYAL,'\',Name,i,writePostfix);
%            imwrite(I,saveName);
%        case 'MUCS'
%            saveName = sprintf('%s%s%s_%d',Location_MUCS,'\',Name,i,writePostfix);
%            imwrite(I,saveName);
%        case 'RBC'
%            saveName = sprintf('%s%s%s_%d',Location_RBC,'\',Name,i,writePostfix);
%            imwrite(I,saveName);
%        case 'WBC'
%            saveName = sprintf('%s%s%s_%d',Location_WBC,'\',Name,i,writePostfix);
%            imwrite(I,saveName);
%        case 'WBCC'
%            saveName = sprintf('%s%s%s_%d',Location_WBCC,'\',Name,i,writePostfix);
%            imwrite(I,saveName);
%        case 'OTHERS'
%            saveName = sprintf('%s%s%s_%d',Location_OTHERS,'\'.Name,i,writePostfix);
%            imwrite(I,saveName);
%    end
%     
%     
% end

实验结果
在这里插入图片描述

  • 5
    点赞
  • 37
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值