fit.py是MXNet的fine-tune.py(参看博文:MXNet的fine-tune.py源码详解)中启动训练的入口,非常值得读一读源码。这个脚本是作者包装好的训练入口,最核心的还是Module类的fit方法(model.fit()就是Module类的对象在条用fit方法)。总的来讲,这个fit.py脚本包含训练的一些配置,导入模型,训练模型和保存模型这几步,接下来详细阐述。建议从最后的主函数fit()开始看起。
import mxnet as mx
import logging
import os
import time
# 这个函数主要是和学习率相关,我们在启动训练的时候一般会添加这个参数:--lr,就是学习率,
# 如果不设置的话就会采用fine-tune.py中的默认lr。lr_factor表示当你要改变lr的时候是以多大比率改变,
# 举个例子,你原来lr是0.1,设置的lr_step_epochs是2,那么当你训练到epoch==2的时候,
# 你的lr就会变成lr*lr_factor,这个lr_factor在fit.py脚本中默认设置为0.1。
# 另外你的lr_step_epochs可以有多个值,比如(2,4,6),表示当epoch达到这些值的时候都要乘以lr_factor。
def _get_lr_scheduler(args, kv):
if 'lr_factor' not in args or args.lr_factor >= 1:
return (args.lr, None)
epoch_size = args.num_examples / args.batch_size
if 'dist' in args.kv_store:
epoch_size /= kv.num_workers
begin_epoch = args.load_epoch if args.load_epoch else 0
step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
lr = args.lr
for s in step_epochs:
if begin_epoch >= s:
lr *= args.lr_factor
if lr != args.lr:
logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch))
steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))
# 导入模型,其实本质还是和fine-tune.py中的导入模型一样采用model.py脚本中的load_checkpoint函数。
# 首先判断你的load_epoch参数有没有设置,没设置的话运行时候会直接报错。
def _load_model(args, rank=0):
if 'load_epoch' not in args or args.load_epoch is None:
return (None, None,