MXNet Image Classification (3): Fine-tuning

   To train a model by fine-tuning, first download a pretrained model, then replace the output layer to match the number of classes in your own dataset, and finally train.

System: Ubuntu 14.04
MXNet: 0.904

1. Data preparation

import mxnet as mx

# paths to the RecordIO files created with im2rec
train_rec = "/mxnet/tools/train-cat.rec"

val_rec = "/mxnet/tools/train-cat_test.rec"


batch_size = 10
num_epoch = 40
train_dataiter = mx.io.ImageRecordIter(
            path_imgrec=train_rec,
            #mean_img=datadir+"/mean.bin",
            rand_crop=True,
            rand_mirror=True,
            data_shape=(3,224,224),
            batch_size=batch_size,
            preprocess_threads=1)
test_dataiter = mx.io.ImageRecordIter(
            path_imgrec=val_rec,
            #mean_img=datadir+"/mean.bin",
            rand_crop=False,
            rand_mirror=False,
            data_shape=(3,224,224),
            batch_size=batch_size,
            preprocess_threads=1)
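As a quick sanity check on the iterator settings above: with batch_size=10 and data_shape=(3, 224, 224), each decoded float32 batch occupies about 5.7 MB. The arithmetic is plain Python (the dataset size is not part of the original post, only the per-batch figure):

```python
# each decoded image is 3 channels x 224 x 224 float32 values (NCHW layout)
batch_size = 10
values_per_image = 3 * 224 * 224                      # 150528 values
bytes_per_batch = batch_size * values_per_image * 4   # float32 = 4 bytes
print(bytes_per_batch)  # 6021120 bytes, roughly 5.7 MB
```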

2. Load the pretrained model

The model can be downloaded from the MXNet Model Zoo. Here VGG16 is used.

sym, arg_params, aux_params = mx.model.load_checkpoint('model/vgg16', 0)
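load_checkpoint(prefix, epoch) expects the two files produced by MXNet's checkpoint convention: a prefix-symbol.json describing the network, and prefix-%04d.params holding the weights for that epoch. A small illustrative helper makes the naming explicit, so you can check that the downloaded Model Zoo files are named correctly:

```python
def checkpoint_files(prefix, epoch):
    """Return the two file names mx.model.load_checkpoint(prefix, epoch) reads."""
    return ("%s-symbol.json" % prefix, "%s-%04d.params" % (prefix, epoch))

print(checkpoint_files("model/vgg16", 0))
# ('model/vgg16-symbol.json', 'model/vgg16-0000.params')
```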

3. Replace the output layer

def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='drop7'):
    """
    symbol: the pretrained network symbol
    arg_params: the argument parameters of the pretrained model
    num_classes: the number of classes in the fine-tune dataset
    layer_name: the name of the layer before the last fully-connected layer
    """
    all_layers = symbol.get_internals()
    net = all_layers[layer_name+'_output']  # don't forget '_output'; in VGG16 the layer above fc8 is drop7
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc8_new')  # new last layer, renamed so it is initialized from scratch
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    # drop the old fc8 weights, whose shape no longer matches the new number of classes
    # (the original post filtered 'fc1', which is the last layer of ResNet, not VGG16)
    new_args = dict({k: arg_params[k] for k in arg_params if 'fc8' not in k})
    return (net, new_args)
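The parameter-filtering step can be checked in isolation. With a toy stand-in for arg_params (the names below are illustrative, not the real VGG16 parameter list), everything belonging to the old fc8 layer is dropped while the convolutional weights survive:

```python
# toy stand-in for the real arg_params dict (names are illustrative)
arg_params = {
    'conv1_1_weight': 'w1',
    'conv1_1_bias': 'b1',
    'fc8_weight': 'w_old',   # shaped for the 1000 ImageNet classes, so it must go
    'fc8_bias': 'b_old',
}
new_args = {k: v for k, v in arg_params.items() if 'fc8' not in k}
print(sorted(new_args))  # ['conv1_1_bias', 'conv1_1_weight']
```

Because the dropped weights are simply absent, allow_missing=True must be passed to fit() later so the new fc8_new layer gets freshly initialized.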

4. Train the model

import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus):
    devs = [mx.gpu(i) for i in range(num_gpus)]
    mod = mx.mod.Module(symbol=symbol, context=devs)
    num_epoch = 8
    mod.fit(train, val,
        num_epoch=num_epoch,
        arg_params=arg_params,
        aux_params=aux_params,
        allow_missing=True,  # the new fc8_new weights are missing and will be initialized
        batch_end_callback = mx.callback.Speedometer(batch_size, 10),
        kvstore='device',
        optimizer='sgd',
        optimizer_params={'learning_rate':0.001},
        initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
        eval_metric='acc')
    mod.save_checkpoint('./vggnew', num_epoch)  # save the fine-tuned model at the epoch actually trained
    metric = mx.metric.Accuracy()
    # score() returns a list of (name, value) pairs; return the accuracy value itself
    return mod.score(val, metric)[0][1]
num_classes = 2  # two classes
batch_per_gpu = 16
num_gpus = 1

(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)
b = mx.viz.plot_network(new_sym)  # visualize the network structure
b.view()
batch_size = batch_per_gpu * num_gpus
#(train, val) = get_iterators(batch_size)
mod_score = fit(new_sym, new_args, aux_params, train_dataiter, test_dataiter, batch_size, num_gpus)
assert mod_score > 0.77, "Low training accuracy."
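A common variation on the training step above (not in the original post) is to freeze the pretrained layers and train only the new classifier; mx.mod.Module accepts a fixed_param_names list for exactly this. Choosing which names to freeze is plain Python; the parameter names below are hypothetical, in practice they come from new_sym.list_arguments():

```python
# hypothetical parameter names; the real list comes from new_sym.list_arguments()
arg_names = ['data', 'conv1_1_weight', 'conv1_1_bias',
             'fc6_weight', 'fc6_bias', 'fc8_new_weight', 'fc8_new_bias']
# freeze everything except the freshly added fc8_new layer (and skip the 'data' input)
fixed = [n for n in arg_names if n != 'data' and not n.startswith('fc8_new')]
print(fixed)  # the pretrained weights that will stay unchanged during training
```

These names would then be passed as mx.mod.Module(symbol=new_sym, context=devs, fixed_param_names=fixed), which keeps the pretrained weights constant so only fc8_new is updated.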


References:

[1] http://mxnet.io/how_to/finetune.html

For environment setup and dataset preparation, see:

  1. MXNet faster-rcnn environment setup
  2. MXNet Image Classification (1): Preparing the dataset

For testing, see:

  1. MXNet Image Classification (4): Testing with the trained model