MATLAB对Squeezenet模型进行迁移学习

使用MATLAB自带的Squeezenet模型进行迁移学习,若没有安装Squeezenet模型支持工具,在命令窗口输入squeezenet,点击下载链接进行安装。
训练环境:Windows10系统,MATLAB20018b,CPU i3 3.7GHz,4GB内存。
使用squeezenet模型进行迁移学习的MATLAB代码如下:

%% 加载数据
clc;close all;clear;
% unzip('MerchData.zip');
Location = 'E:\Graduation_Project\RBC&WBC';
imds = imageDatastore(Location,...
                       'IncludeSubfolders',true,...
                       'LabelSource','foldernames');
[imdsTrain,imdsValidation] = splitEachLabel(imds,0.7);

%% 加载预训练网络
SqueezenetTrain = squeezenet;
% analyzeNetwork(SqueezenetTrain)%展示网络的层次结构和细节信息
% SqueezenetTrain.Layers(1)
inputSize = SqueezenetTrain.Layers(1).InputSize;

%% 替代最终层
if isa(SqueezenetTrain,'SeriesNetwork') 
  lgraph = layerGraph(SqueezenetTrain.Layers); 
else
  lgraph = layerGraph(SqueezenetTrain);
end 
[learnableLayer,classLayer] = findLayersToReplace(lgraph);
[learnableLayer,classLayer] 
numClasses = numel(categories(imdsTrain.Labels));

if isa(learnableLayer,'nnet.cnn.layer.FullyConnectedLayer')
    %isa(obj,'ClassName')确定类是否为指定输入的对象
    %如果obj是指定classCategory中任何类的实例则返回true否则返回false。
    newLearnableLayer = fullyConnectedLayer(numClasses, ...
        'Name','new_fc', ...
        'WeightLearnRateFactor',20, ...
        'BiasLearnRateFactor',20);
    
elseif isa(learnableLayer,'nnet.cnn.layer.Convolution2DLayer')
    newLearnableLayer = convolution2dLayer(1,numClasses, ...
        'Name','new_conv', ...
        'WeightLearnRateFactor',20, ...
        'BiasLearnRateFactor',20);
end

lgraph = replaceLayer(lgraph,learnableLayer.Name,newLearnableLayer);

newClassLayer = classificationLayer('Name','new_classoutput');
lgraph = replaceLayer(lgraph,classLayer.Name,newClassLayer);
% 
% figure('Units','normalized','Position',[0.3 0.3 0.4 0.4]);
% plot(lgraph)
% ylim([0,10])

%% 冻结初始层
layers = lgraph.Layers;
connections = lgraph.Connections;

layers(1:10) = freezeWeights(layers(1:10));
lgraph = createLgraphUsingConnections(layers,connections);

%% 训练网络
pixelRange = [-30 30];
scaleRange = [0.9 1.1];
imageAugmenter = imageDataAugmenter( ...
    'RandXReflection',true, ...
    'RandXTranslation',pixelRange, ...
    'RandYTranslation',pixelRange, ...
    'RandXScale',scaleRange, ...
    'RandYScale',scaleRange);
augimdsTrain = augmentedImageDatastore(inputSize(1:2),imdsTrain, ...
    'DataAugmentation',imageAugmenter);

augimdsValidation = augmentedImageDatastore(inputSize(1:2),imdsValidation);

options = trainingOptions('sgdm', ...
    'MiniBatchSize',10, ...
    'MaxEpochs',6, ...
    'InitialLearnRate',1e-4, ...
    'Shuffle','every-epoch', ...
    'ValidationData',augimdsValidation, ...
    'ValidationFrequency',60, ... 
    'Verbose',true, ...
    'Plots','training-progress');

SqueezenetTrain = trainNetwork(augimdsTrain,lgraph,options);
%% 验证分类图片
[YPred,probs] = classify(SqueezenetTrain,augimdsValidation);
accuracy = mean(YPred == imdsValidation.Labels)
idx = randperm(numel(imdsValidation.Files),4);

figure
for i = 1:4
    subplot(2,2,i)
    I = readimage(imdsValidation,idx(i));
    imshow(I)
    label = YPred(idx(i));
    title(string(label) + ", " + num2str(100*max(probs(idx(i),:)),3) + "%");
end

%% 保存训练好的模型
save Squeezenet_01 SqueezenetTrain;

函数findLayerToReplace

% findLayersToReplace(lgraph) finds the single classification layer and the
% preceding learnable (fully connected or convolutional) layer of the layer
% graph lgraph.
function [learnableLayer,classLayer] = findLayersToReplace(lgraph)

if ~isa(lgraph,'nnet.cnn.LayerGraph')
    error('Argument must be a LayerGraph object.')
end

% Get source, destination, and layer names.
src = string(lgraph.Connections.Source);
dst = string(lgraph.Connections.Destination);
layerNames = string({lgraph.Layers.Name}');

% Find the classification layer. The layer graph must have a single
% classification layer.
isClassificationLayer = arrayfun(@(l) ...
    (isa(l,'nnet.cnn.layer.ClassificationOutputLayer')|isa(l,'nnet.layer.ClassificationLayer')), ...
    lgraph.Layers);

if sum(isClassificationLayer) ~= 1
    error('Layer graph must have a single classification layer.')
end
classLayer = lgraph.Layers(isClassificationLayer);


% Traverse the layer graph in reverse starting from the classification
% layer. If the network branches, throw an error.
currentLayerIdx = find(isClassificationLayer);
while true
    
    if numel(currentLayerIdx) ~= 1
        error('Layer graph must have a single learnable layer preceding the classification layer.')
    end
    
    currentLayerType = class(lgraph.Layers(currentLayerIdx));
    isLearnableLayer = ismember(currentLayerType, ...
        ['nnet.cnn.layer.FullyConnectedLayer','nnet.cnn.layer.Convolution2DLayer']);
    
    if isLearnableLayer
        learnableLayer =  lgraph.Layers(currentLayerIdx);
        return
    end
    
    currentDstIdx = find(layerNames(currentLayerIdx) == dst);
    currentLayerIdx = find(src(currentDstIdx) == layerNames);
    
end

end

在这里插入图片描述
使用训练好的squeezenet模型进行图片分类测试,我训练的模型是对BYST、GRAN、HYAL、MUCS、RBC、WBC、WBCC等图像进行分类:

%% 加载模型
 clc;close all;clear;
load('-mat','E:\MATLAB_Code\Squeezenet_01');
%% 加载测试集
Location = 'E:\image_test\test_02';
imds = imageDatastore(Location,'includeSubfolders',true,'LabelSource','foldernames');
inputSize = SqueezenetTrain.Layers(1).InputSize; 
imdstest = augmentedImageDatastore(inputSize(1:2),imds);
tic;
[YPred,scores] = classify(SqueezenetTrain,imdstest);
%使用训练好的模型对测试集进行分类
disp(['分类所用时间为:',num2str(toc),'秒']);
%% 显示分类结果,绘制混淆矩阵
byst = 'BYST';
BYST = numel(YPred,YPred == byst);
disp(['BYST = ',num2str(BYST)]);
gran = 'GRAN';
GRAN = numel(YPred,YPred == gran);
disp(['GRAN = ',num2str(GRAN)]);
hyal = 'HYAL';
HYAL = numel(YPred,YPred == hyal);
disp(['HYAL = ',num2str(HYAL)]);
mucs = 'MUCS';
MUCS = numel(YPred,YPred == mucs);
disp(['MUCS = ',num2str(MUCS)]);
rbc = 'RBC';
RBC = numel(YPred,YPred == rbc);
disp(['RBC = ',num2str(RBC)]);
wbc = 'WBC';
WBC = numel(YPred,YPred == wbc);
disp(['WBC = ',num2str(WBC)]);
wbcc = 'WBCC';
WBCC = numel(YPred,YPred == wbcc);
disp(['WBCC = ',num2str(WBCC)]);
sum = numel(YPred);
disp(['sum = ',num2str(sum)]);
% 求出每个标签对应的分类数量
% numel(A)  返回数组A的数目
% numel(A,x) 返回数组A在x的条件下的数目
%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%计算精确度%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 YTest = imds.Labels;
 accuracy = mean(YPred == YTest);
 disp(['accuracy = ',num2str(accuracy)]);
 % disp(x)   变量x的值
 % num2str(x)  将数值数值转换为表示数字的字符数组


%% 绘制混淆矩阵
predictLabel = YPred;%通过训练好的模型分类后的标签
actualLabel = YTest;%原始的标签
plotconfusion(actualLabel,predictLabel,'Squeezenet');%绘制混淆矩阵
%    plotconfusion(targets,outputs);绘制混淆矩阵,使用target(true)和output(predict)标签,将标签指定为分类向量或1到N的形式
 %%
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%随机显示测试分类后的图片%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
idx = randperm(numel(imds.Files),9);
figure
for i = 1:9
    subplot(3,3,i);
    I = readimage(imds,idx(i));
    imshow(I);
    label = YPred(idx(i));
    title(string(label) + ',' + num2str(100*max(scores(idx(i),:)),3) + '%');

end
%% 保存分类后的图片
x = numel(imds.Files);
% 图片保存位置
Location_BYST = 'E:\image_classification\Squeezenet\BYST';
Location_GRAN = 'E:\image_classification\Squeezenet\GRAN';
Location_HYAL = 'E:\image_classification\Squeezenet\HYAL';
Location_MUCS = 'E:\image_classification\Squeezenet\MUCS';
Location_RBC  = 'E:\image_classification\Squeezenet\RBC';
Location_WBC  = 'E:\image_classification\Squeezenet\WBC';
Location_WBCC = 'E:\image_classification\Squeezenet\WBCC';
writePostfix = '.bmp';%图片保存后缀
for i = 1:x
    I = readimage(imds,i);
    Label = YPred(i);
    Name = YTest(i);
   switch Label
       case 'BYST'
           saveName = sprintf('%s%s%s_%d',Location_BYST,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'GRAN'
           saveName = sprintf('%s%s%s_%d',Location_GRAN,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'HYAL'
           saveName = sprintf('%s%s%s_%d',Location_HYAL,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'MUCS'
           saveName = sprintf('%s%s%s_%d',Location_MUCS,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'RBC'
           saveName = sprintf('%s%s%s_%d',Location_RBC,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'WBC'
           saveName = sprintf('%s%s%s_%d',Location_WBC,'\',Name,i,writePostfix);
           imwrite(I,saveName);
       case 'WBCC'
           saveName = sprintf('%s%s%s_%d',Location_WBCC,'\',Name,i,writePostfix);
           imwrite(I,saveName);
   end
    
    
end

程序运行结果:
在这里插入图片描述

在这里插入图片描述

  • 5
    点赞
  • 39
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
本课程适合具有一定深度学习基础,希望发展为深度学习之计算机视觉方向的算法工程师和研发人员的同学们。基于深度学习的计算机视觉是目前人工智能最活跃的领域,应用非常广泛,如人脸识别和无人驾驶中的机器视觉等。该领域的发展日新月异,网络模型和算法层出不穷。如何快速入门并达到可以从事研发的高度对新手和中级水平的学生而言面临不少的挑战。精心准备的本课程希望帮助大家尽快掌握基于深度学习的计算机视觉的基本原理、核心算法和当前的领先技术,从而有望成为深度学习之计算机视觉方向的算法工程师和研发人员。本课程系统全面地讲述基于深度学习的计算机视觉技术的原理并进行项目实践。课程涵盖计算机视觉的七大任务,包括图像分类、目标检测、图像分割(语义分割、实例分割、全景分割)、人脸识别、图像描述、图像检索、图像生成(利用生成对抗网络)。本课程注重原理和实践相结合,逐篇深入解读经典和前沿论文70余篇,图文并茂破译算法难点, 使用思维导图梳理技术要点。项目实践使用Keras框架(后端为Tensorflow),学员可快速上手。通过本课程的学习,学员可把握基于深度学习的计算机视觉的技术发展脉络,掌握相关技术原理和算法,有助于开展该领域的研究与开发实战工作。另外,深度学习之计算机视觉方向的知识结构及学习建议请参见本人CSDN博客。本课程提供课程资料的课件PPT(pdf格式)和项目实践代码,方便学员学习和复习。本课程分为上下两部分,其中上部包含课程的前五章(课程介绍、深度学习基础、图像分类、目标检测、图像分割),下部包含课程的后四章(人脸识别、图像描述、图像检索、图像生成)。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值