Load the training data, then prepend the toolbox data folder to each image file name so that the paths are absolute.
data = load('fasterRCNNVehicleTrainingData.mat');
trainingData = data.vehicleTrainingData;
trainingData.imageFilename = fullfile(toolboxdir('vision'),'visiondata', ...
trainingData.imageFilename);
Randomly shuffle the data for training. Seed the random number generator first so that the shuffle is reproducible.
rng(0);
shuffledIdx = randperm(height(trainingData));
trainingData = trainingData(shuffledIdx,:);
Create an image datastore using the files from the table.
imds = imageDatastore(trainingData.imageFilename);
Create a box label datastore using the label columns from the table.
blds = boxLabelDatastore(trainingData(:,2:end));
Combine the datastores.
ds = combine(imds, blds);
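Before training, it can help to confirm that the boxes line up with the images. The following is a minimal sketch, assuming each read of the combined datastore returns a 1-by-3 cell array of image, boxes, and labels. The names sample, I, boxes, labels, and annotated are illustrative and not part of the original example.
sample = read(ds);          % one training sample: {image, boxes, labels}
I = sample{1};
boxes = sample{2};          % M-by-4 matrix, one [x y width height] row per object
labels = sample{3};         % categorical class label for each box
annotated = insertShape(I,'Rectangle',boxes);
figure
imshow(annotated)
reset(ds);                  % rewind so that training starts from the first sample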
Set up the network layers. This example reuses the layer graph of the detector stored in the loaded MAT-file.
lgraph = layerGraph(data.detector.Network);
Configure training options.
options = trainingOptions('sgdm', ...
'MiniBatchSize', 1, ...
'InitialLearnRate', 1e-3, ...
'MaxEpochs', 7, ...
'VerboseFrequency', 200, ...
'CheckpointPath', tempdir);
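Train the detector. Training takes several minutes. Region proposals whose overlap ratio with a ground truth box falls within PositiveOverlapRange are used as positive samples, and those within NegativeOverlapRange as negative samples. Adjust both ranges so that positive training samples tightly overlap the ground truth.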
detector = trainFasterRCNNObjectDetector(ds, lgraph, options, ...
'NegativeOverlapRange',[0 0.3], ...
'PositiveOverlapRange',[0.6 1]);
*************************************************************************
Training a Faster R-CNN Object Detector for the following object classes:
* vehicle
Training on single GPU.
Initializing input data normalization.
|===========================================================================================================================|
| Epoch | Iteration | Time Elapsed | Mini-batch | Mini-batch | Mini-batch | RPN Mini-batch | RPN Mini-batch | Base Learning |
|       |           |  (hh:mm:ss)  |    Loss    |  Accuracy  |    RMSE    |    Accuracy    |      RMSE      |     Rate      |
|===========================================================================================================================|
|     1 |         1 |     00:00:00 |     0.8771 |     97.30% |       0.83 |         91.41% |           0.71 |        0.0010 |
|     1 |       200 |     00:01:15 |     0.5324 |    100.00% |       0.15 |         88.28% |           0.70 |        0.0010 |
|     2 |       400 |     00:02:40 |     0.4732 |    100.00% |       0.15 |         92.19% |           0.63 |        0.0010 |
|     3 |       600 |     00:04:03 |     0.4776 |     97.14% |       0.09 |         96.88% |           0.59 |        0.0010 |
|     3 |       800 |     00:05:23 |     0.5269 |     97.44% |       0.18 |         89.06% |           0.68 |        0.0010 |
|     4 |      1000 |     00:06:44 |     0.9749 |    100.00% |            |         85.16% |           1.00 |        0.0010 |
|     5 |      1200 |     00:08:07 |     1.1952 |     97.62% |       0.13 |         77.34% |           1.27 |        0.0010 |
|     5 |      1400 |     00:09:24 |     0.6577 |    100.00% |            |         76.38% |           0.72 |        0.0010 |
|     6 |      1600 |     00:10:46 |     0.6951 |    100.00% |            |         90.62% |           0.94 |        0.0010 |
|     7 |      1800 |     00:12:08 |     0.5341 |     96.08% |       0.09 |         86.72% |           0.53 |        0.0010 |
|     7 |      2000 |     00:13:26 |     0.3333 |    100.00% |       0.12 |         94.53% |           0.61 |        0.0010 |
|     7 |      2065 |     00:13:52 |     1.0564 |    100.00% |            |         71.09% |           1.23 |        0.0010 |
|===========================================================================================================================|
Detector training complete.
*******************************************************************
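The overlap ranges passed to the training function are overlap ratios between a candidate box and a ground truth box. As a rough illustration, the sketch below computes the intersection-over-union ratio of two boxes with bboxOverlapRatio and compares it with the ranges used in training. Both boxes are hypothetical values chosen for easy arithmetic, not data from this example.
% Boxes are [x y width height]. Both boxes are made-up values.
groundTruthBox = [1 1 10 10];
proposalBox = [6 1 10 10];
% The intersection is 5-by-10 = 50 and the union is 100 + 100 - 50 = 150.
iou = bboxOverlapRatio(groundTruthBox,proposalBox);   % 50/150, about 0.33
% Compare against the ranges passed to trainFasterRCNNObjectDetector.
isPositive = iou >= 0.6 && iou <= 1;
isNegative = iou >= 0 && iou <= 0.3;
Here the ratio falls in the gap between the two ranges, so a proposal like this one would be expected to count as neither a positive nor a negative sample.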
Test the Faster R-CNN detector on a test image.
img = imread('highway.png');
Run the detector.
[bbox, score, label] = detect(detector,img);
Display detection results.
detectedImg = insertShape(img,'Rectangle',bbox);
figure
imshow(detectedImg)
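The detect call also returns a confidence score and a label for each box, which the display step does not use. A minimal sketch of filtering detections by score and annotating the remaining boxes follows. The 0.5 cutoff and the names minScore, keep, and annotatedImg are assumptions for illustration, not values from the original example.
minScore = 0.5;             % hypothetical cutoff; tune it for your data
keep = score >= minScore;   % logical index of detections to keep
if any(keep)
    % Label each remaining box with its confidence score.
    annotatedImg = insertObjectAnnotation(img,'rectangle',bbox(keep,:),score(keep));
else
    annotatedImg = img;     % no detection passed the cutoff
end
figure
imshow(annotatedImg)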