上节讨论过如何使一个简单的cnn网络训练mnist数据集,该节介绍复杂并且使用广泛的使用imagenet网络的预训练模型训练自己的数据集。
Ok首先是自己的数据集了。Matconvnet中训练imagenet的数据集的准备不像caffe这些工具箱弄得那么好,弄个train文件夹,test文件夹,以及两个txt索引就好了,感觉很不人性。后面我将会将其输入改为这种人性的类型输入格式。
这里是有一个网友准备的很小的图像数据库,原始链接
但是其类别索引是从0开始的,这在matlab中是不符合的,所以我将其改成从1开始的。同时添加了一个类class标签的txt,改完的
下载完打开这个文件夹看到:
其中train就是训练所用到的所有图片,test为测试所有图片,train_label为对应图片的名字以及跟随的类标签(从1开始),打开txt可以看到为:
这种格式的txt相信应该很容易从你自己的数据集中弄到。依次类推,test.txt中存放的是test文件夹所有图片的名字以及其类别。
Classind 就是每一类表示的分类的名字。
数据准备好了,放在哪呢?我们在Matconvnet的工具箱目录下新建一个文件夹为data,然后将这个数据集放进去,如下:
我们是在训练好的model上继续训练,所以需要一个model,再在这文件夹下建立一个models文件夹,然后把imagenet-vgg-f.mat放入到models里面。这里我们使用的是vgg-f的model,这个model在前两节说到了,自己去下载。
接着就是网络训练了。再建立一个文件夹train,可以编写函数了。
首先是主函数:
这里复制一下examples中的imagenet里面的一个主函数cnn_dicnn,然后修改一下里面的路径,程序为:
function [net, info] = cnn_dicnn(varargin)
%CNN_DICNN Demonstrates fine-tuning a pre-trained CNN on imagenet dataset
run(fullfile(fileparts(mfilename('fullpath')), ...
'..', 'matconvnet', 'matlab', 'vl_setupnn.m')) ;
% 修改读入文件夹的路径
opts.dataDir = fullfile('data','image') ;
opts.expDir = fullfile('exp', 'image') ;
% 导入预训练的model
opts.modelPath = fullfile('models','imagenet-vgg-f.mat');
[opts, varargin] = vl_argparse(opts, varargin) ;
opts.numFetchThreads = 12 ;
opts.lite = false ;
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat');
opts.train = struct() ;
opts.train.gpus = [];
opts.train.batchSize = 8 ;
opts.train.numSubBatches = 4 ;
opts.train.learningRate = 1e-4 * [ones(1,10), 0.1*ones(1,5)];
opts = vl_argparse(opts, varargin) ;
if ~isfield(opts.train, 'gpus'), opts.train.gpus = []; end;
% -------------------------------------------------------------------------
% Prepare model
% -------------------------------------------------------------------------
net = load(opts.modelPath);
% 修改一下这个model
net = prepareDINet(net,opts);
% -------------------------------------------------------------------------
% Prepare data
% -------------------------------------------------------------------------
% 准备数据格式
if exist(opts.imdbPath,'file')
imdb = load(opts.imdbPath) ;
else
imdb = cnn_image_setup_data('dataDir', opts.dataDir, 'lite', opts.lite) ;
mkdir(opts.expDir) ;
save(opts.imdbPath, '-struct', 'imdb') ;
end
imdb.images.set = imdb.images.sets;
% Set the class names in the network
net.meta.classes.name = imdb.classes.name ;
net.meta.classes.description = imdb.classes.name ;
% % 求训练集的均值
imageStatsPath = fullfile(opts.expDir, 'imageStats.mat') ;
if exist(imageStatsPath)
load(imageStatsPath, 'averageImage') ;
else
averageImage = getImageStats(opts, net.meta, imdb) ;
save(imageStatsPath, 'averageImage') ;
end
% % 用新的均值改变均值
net.meta.normalization.averageImage = averageImage;
% --------------------------------