上一篇:SSD算法代码介绍(一):训练参数配置
主要介绍了训练模型的一些参数配置信息,可以看出在训练脚本train.py中主要是调用train_net.py脚本中的train_net函数进行训练的,因此这一篇博客介绍train_net.py脚本的内容。
train_net.py这个脚本一共包含convert_pretrained,get_lr_scheduler,train_net三个函数,其中最重要的是train_net函数,这个函数也是train.py脚本训练模型时候调用的函数,建议从train_net函数开始看起。
import tools.find_mxnet
import mxnet as mx
import logging
import sys
import os
import importlib
import re
# 导入生成模型可用的数据格式的类,是在dataset文件夹下的iterator.py脚本中实现的,
# 一般采用这种导入脚本中类的方式需要在dataset文件夹下写一个空的__init__.py脚本才能导入
from dataset.iterator import DetRecordIter
from train.metric import MultiBoxMetric # 导入训练时候的评价标准类
# 导入测试时候的评价标准类,这里VOC07MApMetric类继承了MApMetric类,主要内容在MApMetric类中
from evaluate.eval_metric import MApMetric, VOC07MApMetric
from config.config import cfg
from symbol.symbol_factory import get_symbol_train # get_symbol_train函数来导入symbol
def convert_pretrained(name, args):
"""
Special operations need to be made due to name inconsistance, etc
Parameters:
---------
name : str
pretrained model name
args : dict
loaded arguments
Returns:
---------
processed arguments as dict
"""
return args
# get_lr_scheduler函数就是设计你的学习率变化策略,函数的几个输入的意思在这里都介绍得很清楚了,
# lr_refactor_step可以是3或6这样的单独数字,也可以是3,6,9这样用逗号间隔的数字,表示到第3,6,9个epoch的时候就要改变学习率
def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
num_example, batch_size, begin_epoch):
"""
Compute learning rate and refactor scheduler
Parameters:
---------
learning_rate : float
original learning rate
lr_refactor_step : comma separated str
epochs to change learning rate
lr_refactor_ratio : float
lr *= ratio at certain steps
num_example : int
number of training images, used to estimate the iterations given epochs
batch_size : int
training batch size
begin_epoch : int
starting epoch
Returns:
---------
(learning_rate, mx.lr_scheduler) as tuple
"""
assert lr_refactor_ratio > 0
iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
# 学习率的改变一般都是越来越小,不接受学习率越来越大这种策略,在这种情况下采用学习率不变的策略
if lr_refactor_ratio >= 1:
return (learning_rate, None)
else:
lr = learning_rate
epoch_size = num_example // batch_size # 表示每个epoch最少包含多少个batch
# 这个for循环的内容主要是解决当你设置的begin_epoch要大于你的iter_refactor的某些值的时候,
# 会按照lr_refactor_ratio改变你的初始学习率,也就是说这个改变是还没开始训练的时候就做的。
for s in iter_refactor:
if begin_epoch >= s:
lr *= lr_refactor_ratio
# 如果有上面这个学习率的改变,那么打印出改变信息&