To train a model by fine-tuning, first download the corresponding pretrained model, then adjust the output classes to match your own dataset, and finally run training.
System: Ubuntu 14.04
MXNet: 0.904
1. Data preparation
import mxnet as mx

# RecordIO (.rec) files for training and validation (e.g. generated with tools/im2rec.py)
train_iter = "/mxnet/tools/train-cat.rec"
val_iter = "/mxnet/tools/train-cat_test.rec"
batch_size = 10
num_epoch = 40

# Training iterator: random crop and mirror for data augmentation
train_dataiter = mx.io.ImageRecordIter(
    path_imgrec=train_iter,
    # mean_img=datadir+"/mean.bin",
    rand_crop=True,
    rand_mirror=True,
    data_shape=(3, 224, 224),
    batch_size=batch_size,
    preprocess_threads=1)

# Validation iterator: no augmentation
test_dataiter = mx.io.ImageRecordIter(
    path_imgrec=val_iter,
    # mean_img=datadir+"/mean.bin",
    rand_crop=False,
    rand_mirror=False,
    data_shape=(3, 224, 224),
    batch_size=batch_size,
    preprocess_threads=1)
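As an optional sanity check, you can pull one batch from the training iterator and confirm the shapes before training. A minimal sketch, assuming the .rec files above exist:

batch = train_dataiter.next()
print(batch.data[0].shape)   # expected: (10, 3, 224, 224)
print(batch.label[0].shape)  # expected: (10,)
train_dataiter.reset()       # rewind the iterator before handing it to fit()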
2. Load the pretrained model
The model can be downloaded from the MXNet Model Zoo; here we use VGG16.
sym, arg_params, aux_params = mx.model.load_checkpoint('model/vgg16', 0)
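To see what was loaded, you can print the symbol's outputs and a few parameter names. The names shown in the comments are what the model-zoo VGG16 checkpoint is expected to contain and may differ for other checkpoints:

print(sym.list_outputs())             # typically ['prob_output']
print(sorted(arg_params.keys())[:4])  # e.g. ['conv1_1_bias', 'conv1_1_weight', ...]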
3. Modify the output classes
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='drop7'):
    """
    symbol: the pretrained network symbol
    arg_params: the argument parameters of the pretrained model
    num_classes: the number of classes in the fine-tune dataset
    layer_name: the name of the layer just before the last fully-connected layer
    """
    all_layers = symbol.get_internals()
    net = all_layers[layer_name + '_output']  # don't forget '_output'; in VGG16 the layer above fc8 is drop7
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc8_new')  # new name so the old fc8 weights are not reused
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # keep all pretrained weights except those of the replaced output layer (fc8)
    new_args = dict({k: arg_params[k] for k in arg_params if 'fc8' not in k})
return (net, new_args)
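If you fine-tune a different backbone, the right layer_name can be found by listing the symbol's internal outputs. A small sketch; for the model-zoo VGG16 the tail of the list is expected to include 'drop7_output' followed by 'fc8_output':

print(sym.get_internals().list_outputs()[-8:])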
4. Train the model
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus):
    devs = [mx.gpu(i) for i in range(num_gpus)]
    mod = mx.mod.Module(symbol=symbol, context=devs)
    mod.fit(train, val,
            num_epoch=num_epoch,           # defined in step 1
            arg_params=arg_params,
            aux_params=aux_params,
            allow_missing=True,            # the new fc8_new weights are not in arg_params
            batch_end_callback=mx.callback.Speedometer(batch_size, 10),
            kvstore='device',
            optimizer='sgd',
            optimizer_params={'learning_rate': 0.001},
            initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
            eval_metric='acc')
    mod.save_checkpoint('./vggnew', num_epoch)  # save the fine-tuned model
    metric = mx.metric.Accuracy()
    # mod.score returns a list of (name, value) pairs; return the accuracy value
    return dict(mod.score(val, metric))['accuracy']
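A common variant (not used in fit() above) is to freeze the pretrained backbone and train only the new classifier; Module accepts a fixed_param_names argument for this. A sketch, assuming the symbol and weights returned by get_fine_tune_model:

net, args = get_fine_tune_model(sym, arg_params, num_classes=2)
fixed = list(args)  # args holds only pretrained weights; the new fc8_new layer stays trainable
frozen_mod = mx.mod.Module(symbol=net, context=[mx.gpu(0)], fixed_param_names=fixed)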
num_classes = 2  # two classes
num_gpus = 1
(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)
b = mx.viz.plot_network(new_sym)  # visualize the network structure (requires graphviz)
b.view()
# batch_size was set in step 1 and must match the iterators built there
mod_score = fit(new_sym, new_args, aux_params, train_dataiter, test_dataiter, batch_size, num_gpus)
assert mod_score > 0.77, "Low training accuracy."
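Once training finishes, the saved checkpoint can be reloaded for inference. A minimal sketch, assuming the './vggnew' prefix and the epoch number used in fit() above:

sym_ft, arg_ft, aux_ft = mx.model.load_checkpoint('./vggnew', num_epoch)
mod_ft = mx.mod.Module(symbol=sym_ft, context=mx.gpu(0))
mod_ft.bind(for_training=False,
            data_shapes=test_dataiter.provide_data,
            label_shapes=test_dataiter.provide_label)
mod_ft.set_params(arg_ft, aux_ft)
prob = mod_ft.predict(test_dataiter)
print(prob.shape)  # (number of validation images, num_classes)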
References:
[1] http://mxnet.io/how_to/finetune.html