py-faster-rcnn源码解读系列(三)——train.py

本文详细解读py-faster-rcnn项目的train.py源码,包括snapshot、train_model、get_training_roidb和filter_roidb四个关键部分。snapshot实现了自定义快照功能;train_model是训练主流程,负责打印训练信息并控制快照保存;get_training_roidb处理数据集,包括水平翻转图像以减少过拟合;filter_roidb通过is_valid函数筛选有效样本,确保至少包含一个前景或背景框。
摘要由CSDN通过智能技术生成

这是一个简单的solver包装类,主要是为了实现自己的snapshot,值得一提的地方不是太多,主要是为了读者从头到尾的训练过程理解更加连贯,所以为此文单独开一节源码分析。

class SolverWrapper(object):
"""A simple wrapper around Caffe's solver.
This wrapper gives us control over he snapshotting process, which we
use to unnormalize the learned bounding-box regression weights.
"""

def __init__(self, solver_prototxt, roidb, output_dir,
pretrained_model=None):
    """Initialize the SolverWrapper."""
    self.output_dir = output_dir


    if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
    cfg.TRAIN.BBOX_NORMALIZE_TARGETS):

    # RPN can only use precomputed normalization because there are no
    # fixed statistics to compute a priori
    assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

    if cfg.TRAIN.BBOX_REG:
    print 'Computing bounding-box regression targets...'
    self.bbox_means, self.bbox_stds = \
    rdl_roidb.add_bbox_regression_targets(roidb)
    print 'done'

    self.solver = caffe.SGDSolver(solver_prototxt)
    if pretrained_model is not None:
    print ('Loading pretrained model '
    'weights from {:s}').format(pretrained_model)
    self.solver.net.copy_from(pretrained_model)

    self.solver_param = caffe_pb2.SolverParameter()
 
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值