MATLAB中使用AlexNet、VGG、GoogLeNet进行迁移学习

直接贴代码,具体用法见注释:

clc;clear;

net = alexnet; %加载在ImageNet上预训练的网络模型
imageInputSize = [227 227 3];
%加载图像
allImages = imageDatastore('.\data227Alexnet',...
    'IncludeSubfolders',true,...
    'LabelSource','foldernames');
    %划分训练集和验证集
 [training_set,validation_set] = splitEachLabel(allImages,0.7,'randomized');
 %由于原始网络全连接层1000个输出,显然不适用于我们的分类任务,因此在这里替换
layersTransfer = net.Layers(1:end-3);
categories(training_set.Labels)
numClasses = numel(categories(training_set.Labels));
%新的网络
layers = [
    layersTransfer
    fullyConnectedLayer(numClasses,'Name', 'fc','WeightLearnRateFactor',1,'BiasLearnRateFactor',1)
    softmaxLayer('Name', 'softmax')
    classificationLayer('Name', 'classOutput')];

lgraph = layerGraph(layers);
plot(lgraph)
%对数据集进行扩增
augmented_training_set = augmentedImageSource(imageInputSize,training_set);


opts = trainingOptions('adam', ...
    'MiniBatchSize', 32,... % mini batch size, limited by GPU RAM, default 100 on Titan, 500 on P6000
    'InitialLearnRate', 1e-4,... % fixed learning rate
    'LearnRateSchedule','piecewise',...
    'LearnRateDropFactor',0.25,...
    'LearnRateDropPeriod',10,...
    'L2Regularization', 1e-4,... constraint
    'MaxEpochs',20,..
    'ExecutionEnvironment', 'gpu',...
    'ValidationData', validation_set,...
    'ValidationFrequency',80,...
    'ValidationPatience',8,...
    'Plots', 'training-progress')

net = trainNetwork(augmented_training_set, lgraph, opts);

save Alex_Public_32.mat net

[predLabels,predScores] = classify(net, validation_set);
plotconfusion(validation_set.Labels, predLabels)
PerItemAccuracy = mean(predLabels == validation_set.Labels);
title(['overall per image accuracy ',num2str(round(100*PerItemAccuracy)),'%'])

MATLAB中训练神经网络一个非常大的优势就是训练过程中各项指标的可视化,并且最终也会生成一个混淆矩阵显示验证集的结果。

  • 2
    点赞
  • 75
    收藏
    觉得还不错? 一键收藏
  • 15
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 15
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值