Faster R-CNN深度学习进行目标检测
版本2019a
% https://ww2.mathworks.cn/help/deeplearning/ug/object-detection-using-faster-r-cnn-deep-learning.html
%此示例说明如何训练 Faster R-CNN(区域卷积神经网络)目标检测器
%-------------下载预训练的检测器
doTrainingAndEval = true; %要训练检测器则改为true
if ~doTrainingAndEval && ~exist('fasterRCNNResNet50EndToEndVehicleExample.mat','file')
disp('Downloading pretrained detector (118 MB)...');
pretrainedURL = 'https://www.mathworks.com/supportfiles/vision/data/fasterRCNNResNet50EndToEndVehicleExample.mat';
websave('fasterRCNNResNet50EndToEndVehicleExample.mat',pretrainedURL);
end
%------------加载数据集---已标记好目标
%unzip('D:\2ma\examples\deeplearning_shared\vehicleDatasetImages.zip')
%---------unzip vehicleDatasetImages.zip
data = load('vehicleDatasetGroundTruth.mat');%数据集
vehicleDataset = data.vehicleDataset;%
% Add the fullpath to the local vehicle data folder.添加完整路径
vehicleDataset.imageFilename = fullfile(pwd, vehicleDataset.imageFilename);%pwa表示当前路径
%------车辆数据存储在一个包含两列的表中,其中第一列包含图像文件路径,第二列包含车辆边界框。
% Display first few rows of the data set.显示前几行数据形式
vehicleDataset(1:4,:);
%-------Read one of the images.读取其中一张图并显示
I = imread(vehicleDataset.imageFilename{10});
figure
subplot(211)
imshow(I)
title('原始图片')
%----------加入边界框标签。
I2 = insertShape(I, 'Rectangle', vehicleDataset.vehicle{10});%在图像中插入形状
subplot(212)
imshow(I2)
title('原始图片加入边框数据')
% % Resize and display image.重新定义图片大小
% I3 = imresize(I,3);%大小调整因子,指定为正数。如果 scale 小于 1,则输出图像小于输入图像。如果 scale 大于 1,则输出图像大于输入图像
% figure
% imshow(I3)
% title('调整原始图片大小')
%将数据集拆分为用于训练检测器的训练集
%和用于评估检测器的测试集。选择 60% 的数据进行训练。其余用于评估。
% Set random seed to ensure example training reproducibility.
rng(0);
% 随机进行两种数据集分类
shuffledIdx = randperm(height(vehicleDataset));%数据集的序号随机排列
idx = floor(0.6 * height(vehicleDataset));%0.6的数据集有多少
trainingData = vehicleDataset(shuffledIdx(1:idx),:);%训练集
testData = vehicleDataset(shuffledIdx(idx+1:end),:);%测试集
%-----------训练网络-网络已训练好的
%--trainFasterRCNNObjectDetector 分四步训练检测器。
%前两步训练 Faster R-CNN 中使用的区域提议和检测网络。
%最后两个步骤结合了前两个步骤中的网络,从而创建了一个用于检测的网络。
%使用 trainingOptions 为所有步骤指定网络训练选项。
% Options for step 1.
options = trainingOptions('sgdm', ...
'MaxEpochs', 5, ... % 用于训练的最大时期数
'MiniBatchSize', 1, ...% 用于每次训练迭代的小批量大小
'InitialLearnRate', 1e-3, ...
'CheckpointPath', tempdir);
%MiniBatchSize' 属性设置为 1,因为车辆数据集具有不同大小的图像。
%这可以防止它们被批处理在一起进行处理。
%如果训练图像大小相同,则选择大于 1 的 MiniBatchSize 以减少训练时间。
%'CheckpointPath' 属性设置为所有训练选项的临时位置,保护防止停电等
% 训练网络-不训练,已经训练好了
if doTrainingAndEval
% Train Faster R-CNN detector.
% * Use 'resnet50' as the feature extraction network.
% * Adjust the NegativeOverlapRange and PositiveOverlapRange to ensure
% training samples tightly overlap with ground truth.
[detector, info] = trainFasterRCNNObjectDetector(trainingData, 'resnet50', options, ...
'NegativeOverlapRange', [0 0.3], ...
'PositiveOverlapRange', [0.6 1]);
else
% 使用预训练的 ResNet-50 进行特征提取
% Load pretrained detector for the example.
pretrained = load('fasterRCNNResNet50VehicleExample.mat');%下载已训练号的网络模型
detector = pretrained.detector;%参数
end
% Note: This example verified on an Nvidia(TM) Titan X with 12 GB of GPU
% memory. Training this network took approximately 10 minutes using this setup.
% Training time varies depending on the hardware you use.
% % 随机测试集图片原来的图片
% I4 = imread(testData.imageFilename{7});%读入图片2
% figure
% imshow(I4)
% title('测试集的图片')
% I5 = insertShape(I4, 'Rectangle', testData.vehicle{7});%载入测试集图片的边界数据
% figure
% imshow(I5)
% title('测试集的图片及边框')
%----运行探测器-测试集的图片
I6 = imread(testData.imageFilename{10});%读入图片
figure
subplot(221)
imshow(I6)
title('输入探测器的原始图片')
[bboxes,scores] = detect(detector,I6);%得到测试图片的边框数据
%bboxes为边框数据,scores为得分
% Annotate detections in the image.
I7 = insertObjectAnnotation(I6,'rectangle',bboxes,scores);%探测器的结果
subplot(222)
imshow(I7)
title('模型结果')
I8 = imread(testData.imageFilename{10});%读入图片2
I8 = insertShape(I8, 'Rectangle', testData.vehicle{10});%载入测试集图片的边界数据
subplot(223)
imshow(I8)
title('人工标签结果')
问题:
resnet50 requires the Deep Learning Toolbox Model for ResNet-50 Network support package. To install this
support package, use the Add-On Explorer.
输入 应与以下值之一匹配:
‘alexnet’, ‘vgg16’, ‘vgg19’, ‘resnet18’, ‘resnet50’, ‘resnet101’, ‘googlenet’, ‘inceptionv3’,
‘inceptionresnetv2’, ‘squeezenet’, ‘mobilenetv2’
解决
还在学习深度学习中