Learning the MXNet SSD Source Code (Part 0)

train.py

This file is the entry point for training the whole SSD model: it defines the training mode, the data file locations, the network structure, and so on. Below, each function in the file is walked through in turn.
~~~python
if __name__ == '__main__':
    args = parse_args()
~~~
This receives the parameters that need to be defined for training; they are obtained through the function parse_args().
~~~python
def parse_args():
    parser = argparse.ArgumentParser(description='Train a Single-shot detection network')
    parser.add_argument('--train-path', dest='train_path', help='train record to use',
                        default=os.path.join(os.getcwd(), 'data', 'train.rec'), type=str)
    parser.add_argument('--train-list', dest='train_list', help='train list to use',
                        default="", type=str)
    parser.add_argument('--val-path', dest='val_path', help='validation record to use',
                        default=os.path.join(os.getcwd(), 'data', 'val.rec'), type=str)
    parser.add_argument('--val-list', dest='val_list', help='validation list to use',
                        default="", type=str)
    parser.add_argument('--network', dest='network', type=str, default='vgg16_reduced',
                        help='which network to use')
    parser.add_argument('--batch-size', dest='batch_size', type=int, default=32,
                        help='training batch size')
    parser.add_argument('--resume', dest='resume', type=int, default=-1,
                        help='resume training from epoch n')
    parser.add_argument('--finetune', dest='finetune', type=int, default=-1,
                        help='finetune from epoch n, rename the model before doing this')
    parser.add_argument('--pretrained', dest='pretrained', help='pretrained model prefix',
                        default=os.path.join(os.getcwd(), 'model', 'vgg16_reduced'), type=str)
    parser.add_argument('--epoch', dest='epoch', help='epoch of pretrained model',
                        default=1, type=int)
    parser.add_argument('--prefix', dest='prefix', help='new model prefix',
                        default=os.path.join(os.getcwd(), 'model', 'ssd'), type=str)
    parser.add_argument('--gpus', dest='gpus', help='GPU devices to train with',
                        default='0', type=str)
    parser.add_argument('--begin-epoch', dest='begin_epoch', help='begin epoch of training',
                        default=0, type=int)
    parser.add_argument('--end-epoch', dest='end_epoch', help='end epoch of training',
                        default=240, type=int)
    parser.add_argument('--frequent', dest='frequent', help='frequency of logging',
                        default=20, type=int)
    parser.add_argument('--data-shape', dest='data_shape', type=int, default=300,
                        help='set image shape')
    parser.add_argument('--label-width', dest='label_width', type=int, default=350,
                        help='force padding label width to sync across train and validation')
    parser.add_argument('--lr', dest='learning_rate', type=float, default=0.002,
                        help='learning rate')
    parser.add_argument('--momentum', dest='momentum', type=float, default=0.9,
                        help='momentum')
    parser.add_argument('--wd', dest='weight_decay', type=float, default=0.0005,
                        help='weight decay')
    parser.add_argument('--mean-r', dest='mean_r', type=float, default=123,
                        help='red mean value')
    parser.add_argument('--mean-g', dest='mean_g', type=float, default=117,
                        help='green mean value')
    parser.add_argument('--mean-b', dest='mean_b', type=float, default=104,
                        help='blue mean value')
    parser.add_argument('--lr-steps', dest='lr_refactor_step', type=str, default='80, 160',
                        help='refactor learning rate at specified epochs')
    parser.add_argument('--lr-factor', dest='lr_refactor_ratio', type=float, default=0.1,
                        help='ratio to refactor learning rate')
    parser.add_argument('--freeze', dest='freeze_pattern', type=str, default="^(conv1_|conv2_).*",
                        help='freeze layer pattern')
    parser.add_argument('--log', dest='log_file', type=str, default="train.log",
                        help='save training log to file')
    parser.add_argument('--monitor', dest='monitor', type=int, default=0,
                        help='log network parameters every N iters if larger than 0')
    parser.add_argument('--pattern', dest='monitor_pattern', type=str, default=".*",
                        help='monitor parameter pattern, as regex')
    parser.add_argument('--num-class', dest='num_class', type=int, default=20,
                        help='number of classes')
    parser.add_argument('--num-example', dest='num_example', type=int, default=16551,
                        help='number of image examples')
    parser.add_argument('--class-names', dest='class_names', type=str,
                        default='aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa, train, tvmonitor',
                        help='string of comma separated names, or text filename')
    parser.add_argument('--nms', dest='nms_thresh', type=float, default=0.45,
                        help='non-maximum suppression threshold')
    parser.add_argument('--overlap', dest='overlap_thresh', type=float, default=0.5,
                        help='evaluation overlap threshold')
    parser.add_argument('--force', dest='force_nms', action='store_true',
                        help='force non-maximum suppression on different class')
    parser.add_argument('--use-difficult', dest='use_difficult', action='store_true',
                        help='use difficult ground-truths in evaluation')
    parser.add_argument('--no-voc07', dest='use_voc07_metric', action='store_false',
                        help='dont use PASCAL VOC 07 11-point metric')
    args = parser.parse_args()
    return args
~~~
This function defines, for training, the data file locations, the network structure, the object categories, and so on.
~~~python
def parse_class_names(args):
    """ parse # classes and class_names if applicable """
    num_class = args.num_class
    if len(args.class_names) > 0:
        if os.path.isfile(args.class_names):
            # try to open it to read class names
            with open(args.class_names, 'r') as f:
                class_names = [l.strip() for l in f.readlines()]
        else:
            class_names = [c.strip() for c in args.class_names.split(',')]
        assert len(class_names) == num_class, str(len(class_names))
        for name in class_names:
            assert len(name) > 0
    else:
        class_names = None
    return class_names
~~~
This function reads the object class names, either from a text file or from a comma-separated string. With that, the external parameters needed to train the network are essentially complete, and the next step is to enter train_net to do the actual training.
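As a quick sanity check, here is how the comma-separated path behaves (a minimal sketch: the Namespace below is built by hand for illustration rather than coming from parse_args):
~~~python
import argparse

# hand-built Namespace standing in for the output of parse_args()
args = argparse.Namespace(num_class=3, class_names='cat, dog, person')
# 'cat, dog, person' is not a file, so it is split on commas and stripped
print(parse_class_names(args))  # ['cat', 'dog', 'person']
~~~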

train_net.py

get_lr_scheduler

This function defines how the learning rate is adjusted over the course of training.
~~~python
assert lr_refactor_ratio > 0
iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
if lr_refactor_ratio >= 1:
    return (learning_rate, None)
~~~
These lines assert that the learning-rate decay ratio must be positive; if the ratio is 1 or greater, no decay is needed, so the learning rate is returned unchanged without a scheduler.
~~~python
else:
    lr = learning_rate
    epoch_size = num_example // batch_size
    for s in iter_refactor:
        if begin_epoch >= s:
            lr *= lr_refactor_ratio
    if lr != learning_rate:
        logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
    steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
    if not steps:
        return (lr, None)
    lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)
    return (lr, lr_scheduler)
~~~
When the decay ratio is less than 1, any decay steps that begin_epoch has already passed are applied immediately (useful when resuming training), the remaining decay epochs are converted into iteration counts, and the adjusted learning rate is returned together with a MultiFactorScheduler.
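With the defaults from parse_args (num_example=16551, batch_size=32, lr_steps='80, 160', begin_epoch=0), the arithmetic works out as follows; this is a hand-traced sketch, not output from the script:
~~~python
num_example, batch_size = 16551, 32
iter_refactor, begin_epoch = [80, 160], 0

epoch_size = num_example // batch_size   # 16551 // 32 = 517 batches per epoch
steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
print(steps)  # [41360, 82720] -> lr is multiplied by 0.1 at these iterations
~~~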

train_net

This is the entry point that actually trains the network.
It first sets up logging and, if a log file is given, attaches a file handler so the training log is also written to disk.
~~~python
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
if log_file:
    fh = logging.FileHandler(log_file)
    logger.addHandler(fh)
~~~
It then checks that the relevant arguments have the required types and shapes.
~~~python
# check args
if isinstance(data_shape, int):
    # a single int is expanded to a (channels, height, width) tuple
    data_shape = (3, data_shape, data_shape)
assert len(data_shape) == 3 and data_shape[0] == 3
prefix += '_' + net + '_' + str(data_shape[1])

if isinstance(mean_pixels, (int, float)):
    mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
assert len(mean_pixels) == 3, "must provide all RGB mean values"
~~~
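Traced by hand with the default arguments (os.getcwd() omitted from the prefix for brevity), the checkpoint prefix picks up the network name and input size:
~~~python
data_shape = 300
net = 'vgg16_reduced'
prefix = 'model/ssd'

if isinstance(data_shape, int):
    data_shape = (3, data_shape, data_shape)   # -> (3, 300, 300)
prefix += '_' + net + '_' + str(data_shape[1])
print(prefix)  # model/ssd_vgg16_reduced_300
~~~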
Next come the data iterators: the training records are wrapped in train_iter, and if a validation record file is given it is loaded as val_iter.
~~~python
train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
                           label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)
# load the training data

if val_path:
    val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
                             label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
    # load the validation data
else:
    val_iter = None
~~~
The network is then loaded in symbol form and assigned to net.
~~~python
# load symbol
net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                       nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)
~~~
Next, the parameters whose weights should remain fixed (frozen) during training are collected.
~~~python
# define layers with fixed weight/bias
if freeze_layer_pattern.strip():
    re_prog = re.compile(freeze_layer_pattern)
    fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
else:
    fixed_param_names = None
~~~
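With the default pattern ^(conv1_|conv2_).*, the first two VGG convolution stages are frozen. A minimal sketch of the matching; the layer names below are typical VGG argument names, assumed here for illustration:
~~~python
import re

re_prog = re.compile(r"^(conv1_|conv2_).*")
# assumed sample of what net.list_arguments() might return for vgg16_reduced
names = ['conv1_1_weight', 'conv1_2_bias', 'conv2_1_weight',
         'conv3_1_weight', 'fc7_weight']
print([n for n in names if re_prog.match(n)])
# ['conv1_1_weight', 'conv1_2_bias', 'conv2_1_weight']
~~~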
Then it decides how training starts: resuming a checkpoint, finetuning, initializing from a pretrained model, or training from scratch.
~~~python
# load pretrained or resume from previous state
ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
if resume > 0:
    logger.info("Resume training with {} from epoch {}"
        .format(ctx_str, resume))
    _, args, auxs = mx.model.load_checkpoint(prefix, resume)
    begin_epoch = resume
elif finetune > 0:
    logger.info("Start finetuning with {} from epoch {}"
        .format(ctx_str, finetune))
    _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
    begin_epoch = finetune
    # the prediction convolution layers name starts with relu, so it's fine
    fixed_param_names = [name for name in net.list_arguments() \
        if name.startswith('conv')]
elif pretrained:
    logger.info("Start training with {} from pretrained model {}"
        .format(ctx_str, pretrained))
    _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
    args = convert_pretrained(pretrained, args)
else:
    logger.info("Experimental: start training from scratch with {}"
        .format(ctx_str))
    args = None
    auxs = None
    fixed_param_names = None
~~~
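mx.model.load_checkpoint(prefix, epoch) reads two files from disk: the serialized symbol and the parameter snapshot for that epoch. For example (the prefix below is hypothetical, matching the naming built earlier):
~~~python
import mxnet as mx

# reads 'model/ssd_vgg16_reduced_300-symbol.json' and
#       'model/ssd_vgg16_reduced_300-0100.params' from disk
sym, arg_params, aux_params = mx.model.load_checkpoint(
    'model/ssd_vgg16_reduced_300', 100)
~~~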
It then logs the list of frozen parameters and initializes the training module.
~~~python
# helper information
if fixed_param_names:
    logger.info("Freezed parameters: [" + ','.join(fixed_param_names) + ']')

# init training module
mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                    fixed_param_names=fixed_param_names)
~~~
Finally, the fit parameters are assembled and training is launched through mod.fit.
~~~python
# fit parameters
batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
epoch_end_callback = mx.callback.do_checkpoint(prefix)
learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
    lr_refactor_ratio, num_example, batch_size, begin_epoch)
optimizer_params = {'learning_rate': learning_rate,
                    'momentum': momentum,
                    'wd': weight_decay,
                    'lr_scheduler': lr_scheduler,
                    'clip_gradient': None,
                    'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

# run fit net, every n epochs we run evaluation network to get mAP
if voc07_metric:
    valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
else:
    valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)

mod.fit(train_iter,
        val_iter,
        eval_metric=MultiBoxMetric(),
        validation_metric=valid_metric,
        batch_end_callback=batch_end_callback,
        epoch_end_callback=epoch_end_callback,
        optimizer='sgd',
        optimizer_params=optimizer_params,
        begin_epoch=begin_epoch,
        num_epoch=end_epoch,
        initializer=mx.init.Xavier(),
        arg_params=args,
        aux_params=auxs,
        allow_missing=True,
        monitor=monitor)
~~~
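One piece not shown in this walkthrough: the ctx list passed into train_net is built in train.py from the --gpus argument, falling back to CPU when none is given. A reconstruction of that logic under that assumption (check train.py for the exact lines):
~~~python
import mxnet as mx

# parse '--gpus 0,1' into a list of device contexts; empty string -> CPU
ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
ctx = [mx.cpu()] if not ctx else ctx
~~~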

Reposted from: https://www.cnblogs.com/mumua/p/9006247.html
