Faster R-CNN 代码解析 12

lib/fast_rcnn/train.py

# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Train a Fast R-CNN network."""

import caffe
from fast_rcnn.config import cfg
import roi_data_layer.roidb as rdl_roidb
from utils.timer import Timer
import numpy as np
import os
import google.protobuf.text_format
from caffe.proto import caffe_pb2
import google.protobuf as pb2

class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.

    This wrapper gives us control over the snapshotting process, which we
    use to unnormalize the learned bounding-box regression weights.
    """

    def __init__(self, solver_prototxt, roidb, output_dir,
                 pretrained_model=None):
        """Initialize the SolverWrapper.

        Args:
            solver_prototxt: Path to the solver definition; it is handed to
                ``caffe.SGDSolver`` and also parsed as a text-format
                ``SolverParameter`` message.
            roidb: The RoI database passed to the first layer of the net.
            output_dir: Directory that snapshot files are written into.
            pretrained_model: Optional path of weights to copy into the net
                before training starts.
        """
        self.output_dir = output_dir

        if (cfg.TRAIN.HAS_RPN and cfg.TRAIN.BBOX_REG and
                cfg.TRAIN.BBOX_NORMALIZE_TARGETS):
            # RPN can only use precomputed normalization because there are no
            # fixed statistics to compute a priori
            assert cfg.TRAIN.BBOX_NORMALIZE_TARGETS_PRECOMPUTED

        if cfg.TRAIN.BBOX_REG:
            # Per the original annotation: BBOX_REG is False while training
            # the RPN and True while training Fast R-CNN.
            print('Computing bounding-box regression targets...')
            # Adds the regression targets to the roidb and returns the
            # means / stds that snapshot() later uses to unnormalize.
            self.bbox_means, self.bbox_stds = \
                rdl_roidb.add_bbox_regression_targets(roidb)
            print('done')

        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            message = ('Loading pretrained model '
                       'weights from {:s}').format(pretrained_model)
            print(message)
            self.solver.net.copy_from(pretrained_model)

        # Keep a parsed copy of the solver settings; snapshot() and
        # train_model() read snapshot_prefix and display from it.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            pb2.text_format.Merge(f.read(), self.solver_param)

        # Hand the roidb to the first layer of the net. Per the original
        # annotation this is the RoI data layer, whose set_roidb is defined
        # in layer.py -- NOTE(review): confirm against the net prototxt.
        self.solver.net.layers[0].set_roidb(roidb)

    def snapshot(self):
        """Take a snapshot of the network after unnormalizing the learned
        bounding-box regression weights. This enables easy use at test-time.

        The in-memory weights are restored to their normalized values after
        the file has been written (also when saving fails), so training can
        continue unaffected.

        Returns:
            The path of the ``.caffemodel`` file that was written.
        """
        net = self.solver.net

        # `in` replaces the Python-2-only dict.has_key().
        scale_bbox_params = (cfg.TRAIN.BBOX_REG and
                             cfg.TRAIN.BBOX_NORMALIZE_TARGETS and
                             'bbox_pred' in net.params)

        if scale_bbox_params:
            # save original values
            orig_0 = net.params['bbox_pred'][0].data.copy()
            orig_1 = net.params['bbox_pred'][1].data.copy()

            # scale and shift with bbox reg unnormalization; then save
            # snapshot (multiply by the stds, add the means)
            net.params['bbox_pred'][0].data[...] = \
                (net.params['bbox_pred'][0].data *
                 self.bbox_stds[:, np.newaxis])
            net.params['bbox_pred'][1].data[...] = \
                (net.params['bbox_pred'][1].data *
                 self.bbox_stds + self.bbox_means)

        infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                 if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
        filename = (self.solver_param.snapshot_prefix + infix +
                    '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
        filename = os.path.join(self.output_dir, filename)

        try:
            net.save(str(filename))
        finally:
            if scale_bbox_params:
                # Only the saved file carries unnormalized weights; the live
                # net must go back to the normalized ones, even if the save
                # raised, otherwise further training would be corrupted.
                net.params['bbox_pred'][0].data[...] = orig_0
                net.params['bbox_pred'][1].data[...] = orig_1

        print('Wrote snapshot to: {:s}'.format(filename))
        return filename

    def train_model(self, max_iters):
        """Network training loop.

        Steps the solver one iteration at a time until ``max_iters`` is
        reached, snapshotting every ``cfg.TRAIN.SNAPSHOT_ITERS`` iterations
        and once more at the end if the last iteration was not snapshotted.

        Args:
            max_iters: Iteration count at which training stops.

        Returns:
            List of snapshot file paths, in the order they were written.
        """
        last_snapshot_iter = -1
        timer = Timer()
        model_paths = []
        # A solver prototxt that leaves `display` unset yields 0; guard the
        # modulo below so that case disables the speed report instead of
        # raising ZeroDivisionError.
        report_every = 10 * self.solver_param.display
        while self.solver.iter < max_iters:
            # Make one SGD update
            timer.tic()
            self.solver.step(1)
            timer.toc()
            if report_every > 0 and self.solver.iter % report_every == 0:
                print('speed: {:.3f}s / iter'.format(timer.average_time))

            if self.solver.iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = self.solver.iter
                model_paths.append(self.snapshot())

        # Always leave a snapshot of the final state.
        if last_snapshot_iter != self.solver.iter:
            model_paths.append(self.snapshot())
        return model_paths

def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training.

    When ``cfg.TRAIN.USE_FLIPPED`` is set, horizontally-flipped copies of the
    images are appended to ``imdb`` first (a simple form of data
    augmentation); the roidb is then prepared via
    ``rdl_roidb.prepare_roidb``.

    Args:
        imdb: Image database object; it is modified in place.

    Returns:
        ``imdb.roidb`` after flipping (if enabled) and preparation.
    """
    if cfg.TRAIN.USE_FLIPPED:
        # Single-argument print(...) is valid on both Python 2 and 3.
        print('Appending horizontally-flipped training examples...')
        imdb.append_flipped_images()
        print('done')

    print('Preparing training data...')
    # Per the original annotation this adds derived fields such as the
    # per-RoI max overlap and max class to each roidb entry.
    rdl_roidb.prepare_roidb(imdb)
    print('done')

    return imdb.roidb

def filter_roidb(roidb, fg_thresh=None, bg_thresh_hi=None, bg_thresh_lo=None):
    """Remove roidb entries that have no usable RoIs.

    An entry is kept when it has at least one foreground RoI OR at least one
    background RoI (either one is enough):

    * foreground: ``max_overlaps >= fg_thresh``
    * background: ``bg_thresh_lo <= max_overlaps < bg_thresh_hi``

    Args:
        roidb: Sequence of entries, each providing ``entry['max_overlaps']``
            as a NumPy array.
        fg_thresh: Foreground overlap threshold; defaults to
            ``cfg.TRAIN.FG_THRESH`` when None.
        bg_thresh_hi: Exclusive upper bound of the background range;
            defaults to ``cfg.TRAIN.BG_THRESH_HI`` when None.
        bg_thresh_lo: Inclusive lower bound of the background range;
            defaults to ``cfg.TRAIN.BG_THRESH_LO`` when None.

    Returns:
        A new list holding the entries that passed, in their original order.
    """
    # Fall back to the global training config for any threshold the caller
    # did not override; with no overrides this matches the old behavior.
    if fg_thresh is None:
        fg_thresh = cfg.TRAIN.FG_THRESH
    if bg_thresh_hi is None:
        bg_thresh_hi = cfg.TRAIN.BG_THRESH_HI
    if bg_thresh_lo is None:
        bg_thresh_lo = cfg.TRAIN.BG_THRESH_LO

    def is_valid(entry):
        """Return True when the entry has a foreground or background RoI."""
        overlaps = entry['max_overlaps']
        # boxes with sufficient overlap count as foreground
        has_fg = np.any(overlaps >= fg_thresh)
        # background RoIs are those within [bg_thresh_lo, bg_thresh_hi)
        has_bg = np.any((overlaps < bg_thresh_hi) &
                        (overlaps >= bg_thresh_lo))
        return bool(has_fg or has_bg)

    num = len(roidb)
    filtered_roidb = [entry for entry in roidb if is_valid(entry)]
    num_after = len(filtered_roidb)
    # num - num_after is the number of images that were dropped.
    print('Filtered {} roidb entries: {} -> {}'.format(num - num_after,
                                                       num, num_after))
    return filtered_roidb

def train_net(solver_prototxt, roidb, output_dir,
              pretrained_model=None, max_iters=40000):
    """Train a Fast R-CNN network.

    Filters the roidb, builds a SolverWrapper around the given solver
    definition and runs its training loop.

    Args:
        solver_prototxt: Path to the solver definition file.
        roidb: RoI database to train on; unusable entries are removed by
            ``filter_roidb`` before training.
        output_dir: Directory that snapshots are written into.
        pretrained_model: Optional path of weights to initialise from.
        max_iters: Iteration count at which training stops.

    Returns:
        List of snapshot file paths produced during training.
    """
    # Drop entries that have neither a foreground nor a background RoI.
    roidb = filter_roidb(roidb)
    sw = SolverWrapper(solver_prototxt, roidb, output_dir,
                       pretrained_model=pretrained_model)

    # Single-argument print(...) is valid on both Python 2 and 3.
    print('Solving...')
    # Runs the SGD loop and collects the snapshot paths it writes.
    model_paths = sw.train_model(max_iters)
    print('done solving')
    return model_paths
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值