MXNet的预训练:fine-tune.py源码详解

在MXNet框架下,如果要在一个预训练的模型上用你的数据fine-tune一个模型(或者叫迁移学习,即你的模型的参数的初始化不再是随机初始化,而是用别人的在大数据集上训练过的模型的参数来初始化你的模型参数),可以采用MXNet项目自带的fine-tune.py脚本,路径是~/mxnet/example/image-classification/fine-tune.py,这里的mxnet就是你从mxnet官方git上clone下来的项目名称,git地址:https://github.com/dmlc/mxnet

接下来详细fine-tune.py的内容,总的流程就是:解析参数,导入预训练的模型,修改预训练模型的最后分类层,开始训练。

import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet
from common import data, fit, modelzoo
import mxnet as mx

'''
这个函数的作用就是从原来导入的一个symbol(就是你导入的这个文件:###-symbol.json),
生成一个新的symbol:net和arguments:new_args,new_args是需要训练的参数名
'''
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name):
    """
    symbol: the pre-trained network symbol
    arg_params: the argument parameters of the pre-trained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer
    """
    all_layers = sym.get_internals()
    net = all_layers[layer_name+'_output']
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc')
    net = mx.symbol.SoftmaxOutput(data=net, name='softmax')
    new_args = dict({k:arg_params[k] for k in arg_params if 'fc' not in k})
    return (net, new_args)

if __name__ == "__main__":
    # parse args
   # 在这个参数解析中,主要采用~/mxnet/example/image-classification/common目录下的fit.py中的add_fit_args()函数,
   # data.py中的add_data_args()函数和add_data_aug_args()函数。这三个函数都是和参数配置相关。
    parser = argparse.ArgumentParser(description="fine-tune a dataset",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    train = fit.add_fit_args(parser)
    data.add_data_args(parser)
    aug = data.add_data_aug_args(parser)
    parser.add_argument('--pretrained-model', type=str,
                        help='the pre-trained model')
    parser.add_argument('--layer-before-fullc', type=str, default='flatten0',
                        help='the name of the layer before the last fullc layer')
    # use less augmentations for fine-tune
    data.set_data_aug_level(parser, 1)
    # use a small learning rate and less regularizations
    parser.set_defaults(image_shape='3,224,224', num_epochs=30,
                        lr=
  • 5
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值