在MXNet框架下,如果要在一个预训练的模型上用你的数据fine-tune一个模型(或者叫迁移学习,即你的模型的参数的初始化不再是随机初始化,而是用别人的在大数据集上训练过的模型的参数来初始化你的模型参数),可以采用MXNet项目自带的fine-tune.py脚本,路径是~/mxnet/example/image-classification/fine-tune.py,这里的mxnet就是你从mxnet官方git上clone下来的项目名称,git地址:https://github.com/dmlc/mxnet。
接下来详细fine-tune.py的内容,总的流程就是:解析参数,导入预训练的模型,修改预训练模型的最后分类层,开始训练。
import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet
from common import data, fit, modelzoo
import mxnet as mx
'''
这个函数的作用就是从原来导入的一个symbol(就是你导入的这个文件:###-symbol.json),
生成一个新的symbol:net和arguments:new_args,new_args是需要训练的参数名
'''
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name):
"""
symbol: the pre-trained network symbol
arg_params: the argument parameters of the pre-trained model
num_classes: the number of classes for the fine-tune datasets
layer_name: the layer name before the last fully-connected layer
"""
all_layers = sym.get_internals()
net = all_layers[layer_name+'_output']
net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc')
net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
new_args = dict({k:arg_params[k] for k in arg_params if 'fc' not in k})
return (net, new_args)
if __name__ == "__main__":
# parse args
# 在这个参数解析中,主要采用~/mxnet/example/image-classification/common目录下的fit.py中的add_fit_args()函数,
# data.py中的add_data_args()函数和add_data_aug_args()函数。这三个函数都是和参数配置相关。
parser = argparse.ArgumentParser(description="fine-tune a dataset",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
train = fit.add_fit_args(parser)
data.add_data_args(parser)
aug = data.add_data_aug_args(parser)
parser.add_argument('--pretrained-model', type=str,
help='the pre-trained model')
parser.add_argument('--layer-before-fullc', type=str, default='flatten0',
help='the name of the layer before the last fullc layer')
# use less augmentations for fine-tune
data.set_data_aug_level(parser, 1)
# use a small learning rate and less regularizations
parser.set_defaults(image_shape='3,224,224', num_epochs=30,
lr=