模型蒸馏的具体意思是,从一个大型深度学习模型中抽取(transfers)小型模型来实现对技术的迁移。模型蒸馏的过程很像重新利用和重新学习知识,它能够帮助改善不同模型之间的性能,从而获得高度精准的预测结果。
代码如下:
%定义变量
net = densenet201();%使用densenet201模型来定义网络
imgDir = 'F:\测试图片路径'%存放测试图片的路径
batchSize = 32;%设置mini batch大小
%使用模型蒸馏将原始模型中掩码量进行处理
distilledNet = distill(net,'softmax');
%通过ImageDatastore函数从imgDir路径读取所有图片数据
imds = imageDatastore(imgDir);
%在densenet201模型上定义一个分类器
classifier = trainNetwork(imds, distilledNet, 'MiniBatchSize', batchSize);
%我们还可以使用其他机器学习算法(如svm、knn等)也可以对此结果进行训练
%测试:存放待测试图片的文件夹路径
testImgDir = 'F:\测试图片文件夹路径';
%从 testImgDir路径读取所有待测试照片
imdsTest = imageDatastore(testImgDir);
%使用densenet201模型的蒸馏版本进行图像分类,并输出预测标签
preLabel = classify(classifier,imdsTest);
%打印输出结果
disp('模型预测结果:');
disp(preLabel);