MXNet的训练入口:fit.py源码详解

本文深入探讨MXNet中的fit.py源码,它是fine-tune.py的训练启动点,主要通过Module类的fit方法进行模型训练。fit.py涉及训练配置、模型导入、训练过程及模型保存等关键步骤,适合从主函数fit()开始阅读。
摘要由CSDN通过智能技术生成

fit.py是MXNet的fine-tune.py(参看博文:MXNet的fine-tune.py源码详解)中启动训练的入口,非常值得读一读源码。这个脚本是作者包装好的训练入口,最核心的还是Module类的fit方法(model.fit()就是Module类的对象在条用fit方法)。总的来讲,这个fit.py脚本包含训练的一些配置,导入模型,训练模型和保存模型这几步,接下来详细阐述。建议从最后的主函数fit()开始看起。

import mxnet as mx
import logging
import os
import time


# 这个函数主要是和学习率相关,我们在启动训练的时候一般会添加这个参数:--lr,就是学习率,
# 如果不设置的话就会采用fine-tune.py中的默认lr。lr_factor表示当你要改变lr的时候是以多大比率改变,
# 举个例子,你原来lr是0.1,设置的lr_step_epochs是2,那么当你训练到epoch==2的时候,
# 你的lr就会变成lr*lr_factor,这个lr_factor在fit.py脚本中默认设置为0.1。
# 另外你的lr_step_epochs可以有多个值,比如(2,4,6),表示当epoch达到这些值的时候都要乘以lr_factor。
def _get_lr_scheduler(args, kv):
    if 'lr_factor' not in args or args.lr_factor >= 1:
        return (args.lr, None)
    epoch_size = args.num_examples / args.batch_size
    if 'dist' in args.kv_store:
        epoch_size /= kv.num_workers
    begin_epoch = args.load_epoch if args.load_epoch else 0
    step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
    lr = args.lr
    for s in step_epochs:
        if begin_epoch >= s:
            lr *= args.lr_factor
    if lr != args.lr:
        logging.info('Adjust learning rate to %e for epoch %d' %(lr, begin_epoch))

    steps = [epoch_size * (x-begin_epoch) for x in step_epochs if x-begin_epoch > 0]
    return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))


# 导入模型,其实本质还是和fine-tune.py中的导入模型一样采用model.py脚本中的load_checkpoint函数。
# 首先判断你的load_epoch参数有没有设置,没设置的话运行时候会直接报错。
def _load_model(args, rank=0):
    if 'load_epoch' not in args or args.load_epoch is None:
        return (None, None, 
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值