深度学习(十四):详解Matconvnet使用imagenet模型训练自己的数据集

上节讨论过如何使一个简单的cnn网络训练mnist数据集,该节介绍复杂并且使用广泛的使用imagenet网络的预训练模型训练自己的数据集。

Ok首先是自己的数据集了。Matconvnet中训练imagenet的数据集的准备不像caffe这些工具箱弄得那么好,弄个train文件夹,test文件夹,以及两个txt索引就好了,感觉很不人性。后面我将会将其输入改为这种人性的类型输入格式。

这里是有一个网友准备的很小的图像数据库,原始链接

但是其类别索引是从0开始的,这在matlab中是不符合的,所以我将其改成从1开始的。同时添加了一个类class标签的txt,改完的

数据集下载地址

下载完打开这个文件夹看到:
这里写图片描述

其中train就是训练所用到的所有图片,test为测试所有图片,train_label为对应图片的名字以及跟随的类标签(从1开始),打开txt可以看到为:

这里写图片描述

这种格式的txt相信应该很容易从你自己的数据集中弄到。依次类推,test.txt中存放的是test文件夹所有图片的名字以及其类别。

Classind 就是每一类表示的分类的名字。
这里写图片描述

数据准备好了,放在哪呢?我们在Matconvnet的工具箱目录下新建一个文件夹为data,然后将这个数据集放进去,如下:

这里写图片描述

我们是在训练好的model上继续训练,所以需要一个model,再在这文件夹下建立一个models文件夹,然后把imagenet-vgg-f.mat放入到models里面。这里我们使用的是vgg-f的model,这个model在前两节说到了,自己去下载。

接着就是网络训练了。再建立一个文件夹train,可以编写函数了。

首先是主函数:
这里复制一下examples中的imagenet里面的一个主函数cnn_dicnn,然后修改一下里面的路径,程序为:

function [net, info] = cnn_dicnn(varargin)
%CNN_DICNN Demonstrates fine-tuning a pre-trained CNN on imagenet dataset

run(fullfile(fileparts(mfilename('fullpath')), ...
  '..', 'matconvnet', 'matlab', 'vl_setupnn.m')) ;
% 修改读入文件夹的路径
opts.dataDir = fullfile('data','image') ;
opts.expDir  = fullfile('exp', 'image') ;
% 导入预训练的model
opts.modelPath = fullfile('models','imagenet-vgg-f.mat');
[opts, varargin] = vl_argparse(opts, varargin) ;

opts.numFetchThreads = 12 ;

opts.lite = false ;
opts.imdbPath = fullfile(opts.expDir, 'imdb.mat');

opts.train = struct() ;
opts.train.gpus = [];
opts.train.batchSize = 8 ;
opts.train.numSubBatches = 4 ;
opts.train.learningRate = 1e-4 * [ones(1,10), 0.1*ones(1,5)];

opts = vl_argparse(opts, varargin) ;
if ~isfield(opts.train, 'gpus'), opts.train.gpus = []; end;

% -------------------------------------------------------------------------
%                                                             Prepare model
% -------------------------------------------------------------------------
net = load(opts.modelPath);
% 修改一下这个model
net = prepareDINet(net,opts);
% -------------------------------------------------------------------------
%                                                              Prepare data
% -------------------------------------------------------------------------
% 准备数据格式
if exist(opts.imdbPath,'file')
  imdb = load(opts.imdbPath) ;
else
  imdb = cnn_image_setup_data('dataDir', opts.dataDir, 'lite', opts.lite) ;
  mkdir(opts.expDir) ;
  save(opts.imdbPath, '-struct', 'imdb') ;
end

imdb.images.set = imdb.images.sets;

% Set the class names in the network
net.meta.classes.name = imdb.classes.name ;
net.meta.classes.description = imdb.classes.name ;

% % 求训练集的均值
imageStatsPath = fullfile(opts.expDir, 'imageStats.mat') ;
if exist(imageStatsPath)
  load(imageStatsPath, 'averageImage') ;
else
    averageImage = getImageStats(opts, net.meta, imdb) ;
    save(imageStatsPath, 'averageImage') ;
end
% % 用新的均值改变均值
net.meta.normalization.averageImage = averageImage;
% --------------------------------
评论 285
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值