从Faster RCNN开始学python(六)

本文详细解析了Faster RCNN训练网络的代码,包括filter_roidb函数如何过滤无效图片,SolverWrapper类的初始化过程,以及train_model函数中的迭代和snapshot函数的工作原理。最后讨论了模型保存和删除的过程。
摘要由CSDN通过智能技术生成

在这篇中,我们主要来解析下训练网络的代码:
代码地址为:py-faster-rcnn\lib\fast_rcnn\train.py
我们根据顺序依次贴上代码:

def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network."""

    roidb = filter_roidb(roidb)
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    print 'Solving...'
    model_paths = sw.train_model(max_iters)
    print 'done solving'
    return model_paths

在train_net代码中首先调用了子程序:filter_roidb(),我们进入到该程序中查看代码功能:

def filter_roidb(roidb):
“”“Remove roidb entries that have no usable RoIs.”“”

def is_valid(entry):
    # Valid images have:
    #   (1) At least one foreground RoI OR
    #   (2) At least one background RoI
    overlaps = entry['max_overlaps']
    # find boxes with sufficient overlap
    fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
    # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
    bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                       (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
    # image is only valid if such boxes exist
    valid = len(fg_inds) > 0 or len(bg_inds) > 0
    return valid

num = len(roidb)
filtered_roidb = [entry for entry in roidb if is_valid(entry)]
num_after = len(filtered_roidb)
print 'Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                   num, num_after)
return filtered_roidb

首先我们来再次说明下roidb是训练数据的各种信息。
在for函数中先是讲roidb中的数据依次提取,导入到is_valid()函数中。

def is_valid(entry):
    # Valid images have:
    #   (1) At least one foreground RoI OR
    #   (2) At least one background RoI
    overlaps = entry['max_overlaps']
    # find boxes with sufficient overlap
    fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0]
    # Select background RoIs as those within [BG_THRESH_LO, BG_THRESH_HI)
    bg_inds = np.where((overlaps < cfg.TRAIN.BG_THRESH_HI) &
                       (overlaps >= cfg.TRAIN.BG_THRESH_LO))[0]
    # image is only valid if such boxes exist
    valid = len(fg_inds) > 0 or len(bg_inds) > 0
    return valid

函数在注释中说明了要返回true的两个条件需要满足至少一个(1)至少背景图中有一个区域是感兴趣的,(2)至少在前端中有一个区域是感兴趣的。

先是提取了entry[‘max_overlaps’]的值,再次说明max_overlaps值是每个类中最大的重叠值
np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0] 表达的是在overlaps中大于前端阈值cfg.TRAIN.FG_THRESH的行号赋值给fg_inds。同理后端背景如bg_inds。最终判图片是否满足这两个条件之一判断图片是否为有效的图片。所有有效的图片都会被保留下来给filtered_roidb.最后把这个过滤过的图像返回。这个是把一些无效的图片过滤掉

回到train_net()函数,经过过滤操作后,此时的roidb已经是经过过滤后的训练数据。
进入到下一步中,引用到的函数为一个类SolverWrapper:
先来分析这个类的初始化函数:

class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.
    This wrapper gives us control over he snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """

    def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper."""
        self.output_dir = output_dir

        if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
            cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
            # RPN can only use precomputed normalization because there are no
            # fixed statistics to compute a priori
            assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

        if cfg.TRAIN.BBOX_REG:
            print 'Computing bounding-box regression targets...'
            self.bbox_means, self.bbox_stds = \
                    rdl_roidb.add_bbox_regression_targets(roidb)
            print 'done'

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            print ('Loading pretrained model '
                   'weights from {:s}').format(pretrained_model)
            self.solver.net.copy_from(pretrained_model)

        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        self.solver.net.layers[0].set_roidb(roidb)

因为没有固定的统计来计算先验,RPN只能使用预先计算的归一化,所以在初始化中第一步是预先计算的归一化。
然后是把rdl_roidb内的一些信息复制给这个类。
设定求解器为SGD,内部的一些设定为solver_prototxt,这个为预先设定好的信息,在程序solvers, max_iters, rpn_test_prototxt = get_solvers(args.net_name)中就预先读取好了。
然后是载入预先训练好的前置模型。
然后设定求解的参数self.solver_param = caffe_pb2.SolverParameter()
最后读取了文件solver_prototxt内的信息。
在定义好sw为类SolverWrapper后,程序调用了类中的子函数,train_model()

def train_model(self, max_iters):
    """Network training loop."""
    last_snapshot_iter = -1
    timer = Timer()
    model_paths = []
    while self.solver.iter < max_iters:
        # Make one SGD update
        timer.tic()
        self.solver.step(1)
        timer.toc()
        if self.solver.iter % (10 * self.solver_param.display) == 0:
            print 'speed: {:.3f}s / iter'.format(timer.average_time)

        if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
            last_snapshot_iter = self.solver.iter
            model_paths.append(self.snapshot())

    if last_snapshot_iter != self.solver.iter:
        model_paths.append(self.snapshot())
    return model_paths

在程序中通过判断一系列条件,比如迭代次数的判断。主要的语句为model_paths.append(self.snapshot()),我们转来看看snapshot函数的内容

def snapshot(self):
    """Take a snapshot of the network after unnormalizing the learned
    bounding-box regression weights. This enables easy use at test-time.
    """
    net = self.solver.net

    scale_bbox_params = (cfg.TRAIN.BBOX_REG and
                         cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
                         net.params.has_key('bbox_pred'))

    if scale_bbox_params:
        # save original values
        orig_0 = net.params['bbox_pred'][0].data.copy()
        orig_1 = net.params['bbox_pred'][1].data.copy()

        # scale and shift with bbox reg unnormalization; then save snapshot
        net.params['bbox_pred'][0].data[...] = \
                (net.params['bbox_pred'][0].data *
                 self.bbox_stds[:, np.newaxis])
        net.params['bbox_pred'][1].data[...] = \
                (net.params['bbox_pred'][1].data *
                 self.bbox_stds + self.bbox_means)

    infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
             if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
    filename = (self.solver_param.snapshot_prefix + infix +
                '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
    filename = os.path.join(self.output_dir, filename)

    net.save(str(filename))
    print 'Wrote snapshot to: {:s}'.format(filename)

    if scale_bbox_params:
        # restore net to original state
        net.params['bbox_pred'][0].data[...] = orig_0
        net.params['bbox_pred'][1].data[...] = orig_1
    return filename

在程序中,文件保存了未归一化之前的数据,这是为了之后的测试进行准备。然后把许多的参数和设置都保存在net中,路径为设置好的filename。最后在判断条件后把net.params[‘bbox_pred’]的两列数据进行复原。返回的是保存的文件路径,也是模型路径。

最后在train_net中返回了模型的路径model_paths。

回到我们的主文件中的子程序train_rpn()中,第一部分剩下最后一个代码:

# Cleanup all but the final model
for i in model_paths[:-1]:
    os.remove(i)
rpn_model_path = model_paths[-1]
# Send final model path through the multiprocessing queue
queue.put({'model_path': rpn_model_path})

代码显示把前面所有的模型都删除,然后把最后一个模型给了进程queue中的model_path.

感觉写的有点潦草不好理解,这个源代码就这么不好理解········代码写作风格真的是········

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值