SSD算法代码介绍(二):训练算法整体架构

上一篇:SSD算法代码介绍(一):训练参数配置
主要介绍了训练模型的一些参数配置信息,可以看出在训练脚本train.py中主要是调用train_net.py脚本中的train_net函数进行训练的,因此这一篇博客介绍train_net.py脚本的内容。

train_net.py这个脚本一共包含convert_pretrained,get_lr_scheduler,train_net三个函数,其中最重要的是train_net函数,这个函数也是train.py脚本训练模型时候调用的函数,建议从train_net函数开始看起。

import tools.find_mxnet
import mxnet as mx
import logging
import sys
import os
import importlib
import re
# 导入生成模型可用的数据格式的类,是在dataset文件夹下的iterator.py脚本中实现的,
# 一般采用这种导入脚本中类的方式需要在dataset文件夹下写一个空的__init__.py脚本才能导入
from dataset.iterator import DetRecordIter 
from train.metric import MultiBoxMetric # 导入训练时候的评价标准类
# 导入测试时候的评价标准类,这里VOC07MApMetric类继承了MApMetric类,主要内容在MApMetric类中
from evaluate.eval_metric import MApMetric, VOC07MApMetric 
from config.config import cfg
from symbol.symbol_factory import get_symbol_train # get_symbol_train函数来导入symbol

def convert_pretrained(name, args):
    """
    Special operations need to be made due to name inconsistance, etc

    Parameters:
    ---------
    name : str
        pretrained model name
    args : dict
        loaded arguments

    Returns:
    ---------
    processed arguments as dict
    """
    return args


# get_lr_scheduler函数就是设计你的学习率变化策略,函数的几个输入的意思在这里都介绍得很清楚了,
# lr_refactor_step可以是3或6这样的单独数字,也可以是3,6,9这样用逗号间隔的数字,表示到第3,6,9个epoch的时候就要改变学习率
def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute learning rate and refactor scheduler

    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs to change learning rate
    lr_refactor_ratio : float
        lr *= ratio at certain steps
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch

    Returns:
    ---------
    (learning_rate, mx.lr_scheduler) as tuple
    """
    assert lr_refactor_ratio > 0
    iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
    # 学习率的改变一般都是越来越小,不接受学习率越来越大这种策略,在这种情况下采用学习率不变的策略
    if lr_refactor_ratio >= 1: 
        return (learning_rate, None)
    else:
        lr = learning_rate
        epoch_size = num_example // batch_size # 表示每个epoch最少包含多少个batch
# 这个for循环的内容主要是解决当你设置的begin_epoch要大于你的iter_refactor的某些值的时候,
# 会按照lr_refactor_ratio改变你的初始学习率,也就是说这个改变是还没开始训练的时候就做的。
        for s in iter_refactor: 
            if begin_epoch >= s:
                lr *= lr_refactor_ratio
# 如果有上面这个学习率的改变,那么打印出改变信息&
  • 14
    点赞
  • 54
    收藏
    觉得还不错? 一键收藏
  • 35
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 35
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值