Object Detection - Understanding the Faster R-CNN Training Process Source Code

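This post examines the Faster R-CNN training process in depth, from data loading and RoIDataLayer to the training network, covering data preparation, RoI processing, and mini-batch construction.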

Understanding the Faster R-CNN Training Process Source Code

Everything starts from the main function of the training script ./tools/train_net.py.

Data loading layer: RoIDataLayer

First:

imdb, roidb = combined_roidb(args.imdb_name) # input argument imdb_name; defaults to voc_2007_trainval (the dataset name)
print '{:d} roidb entries'.format(len(roidb))

Next, the function combined_roidb:

def combined_roidb(imdb_names):
    def get_roidb(imdb_name):
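        imdb = get_imdb(imdb_name) # function in factory.py; returns the pascal_voc dataset object
        # by default, get_imdb returns pascal_voc('trainval', '2007')
        # this only sets imdb attributes such as the image path and the image-name index; no image data is read yet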

        print 'Loaded dataset `{:s}` for training'.format(imdb.name)
        imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
        roidb = get_training_roidb(imdb)
        return roidb

    roidbs = [get_roidb(s) for s in imdb_names.split('+')] 
    # with the default argument, imdb_names.split('+') is ['voc_2007_trainval']
    # each name is passed to the inner function get_roidb

    roidb = roidbs[0]
    if len(roidbs) > 1:
        for r in roidbs[1:]:
            roidb.extend(r)
        imdb = datasets.imdb.imdb(imdb_names)
    else:
        imdb = get_imdb(imdb_names)
    return imdb, roidb
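The '+' handling in combined_roidb can be seen in isolation with the sketch below. It is a hedged illustration, not the repository code: get_roidb is replaced by a hypothetical stub that looks names up in a hard-coded table (FAKE_ROIDBS), and combine_roidbs is a made-up name. It shows that a name such as voc_2007_trainval+voc_2012_trainval is split on '+' and the per-dataset roidb lists are concatenated in place with extend().

```python
# Hypothetical stand-in data: the real get_roidb builds an imdb and returns
# one roidb entry (a dict) per training image.
FAKE_ROIDBS = {
    'voc_2007_trainval': [{'image': '2007_a.jpg'}, {'image': '2007_b.jpg'}],
    'voc_2012_trainval': [{'image': '2012_a.jpg'}],
}


def get_roidb(imdb_name):
    # Return a copy so that extend() below cannot mutate the table above
    return list(FAKE_ROIDBS[imdb_name])


def combine_roidbs(imdb_names):
    """Split a '+'-joined name string and concatenate the per-dataset roidbs."""
    roidbs = [get_roidb(s) for s in imdb_names.split('+')]
    roidb = roidbs[0]
    for r in roidbs[1:]:
        roidb.extend(r)  # same in-place concatenation as combined_roidb
    return roidb


if __name__ == '__main__':
    print(len(combine_roidbs('voc_2007_trainval')))                    # 2
    print(len(combine_roidbs('voc_2007_trainval+voc_2012_trainval')))  # 3
```

The stub returns a copy because extend() mutates the first list in place.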

The class behind the pascal_voc dataset object:

class pascal_voc(imdb):  # subclass of imdb
    def __init__(self, image_set, year, devkit_path=None):
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
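        # honor devkit_path if given; '/data/VOCdevkit' is this post's local path
        self._devkit_path = devkit_path if devkit_path is not None else '/data/VOCdevkit'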
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__', # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        self._roidb_handler = self.selective_search_roidb
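        self._salt = str(uuid.uuid4())  # random salt appended to result file names when config['use_salt'] is True
        self._comp_id = 'comp4'  # PASCAL VOC competition ID used in result file names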

        # PASCAL specific config options
        self.config = {'cleanup'     : True,
                       'use_salt'    : True,
                       'use_diff'    : False,
                       'matlab_eval' : False,
                       'rpn_file'    : None,
                       'min_size'    : 2}

        assert os.path.exists(self._devkit_path), \
                'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
                'Path does not exist: {}'.format(self._data_path)



# base class imdb, from datasets/imdb.py
class imdb(object):
    """Image database."""

    def __init__(self, name):
        self._name = name
        self._num_classes = 0
        self._classes = []
        self._image_index = []
        self._obj_proposer = 'selective_search'
        self._roidb = None
        self._roidb_handler = self.default_roidb
        # Use this dict for storing dataset specific config options
        self.config = {}
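Two pieces of pascal_voc.__init__ above can be reproduced on their own: the _class_to_ind mapping and the _image_index list. _load_image_set_index is not shown in this post; the sketch below assumes the standard VOC layout, in which image ids are listed one per line in ImageSets/Main/<image_set>.txt and images live in JPEGImages/<id>.jpg. The function names are hypothetical, the class tuple is shortened, and range replaces Python 2's xrange.

```python
import os
import tempfile

CLASSES = ('__background__', 'aeroplane', 'bicycle', 'bird')  # shortened list


def build_class_to_ind(classes):
    # Same idea as dict(zip(self.classes, xrange(self.num_classes))):
    # class name -> integer label, with background at label 0
    return dict(zip(classes, range(len(classes))))


def load_image_set_index(data_path, image_set):
    """Read one image id per line from ImageSets/Main/<image_set>.txt (assumed VOC layout)."""
    image_set_file = os.path.join(data_path, 'ImageSets', 'Main', image_set + '.txt')
    assert os.path.exists(image_set_file), 'Path does not exist: {}'.format(image_set_file)
    with open(image_set_file) as f:
        # strip newlines and skip blank lines
        return [line.strip() for line in f if line.strip()]


def image_path_at(data_path, index, ext='.jpg'):
    # Only builds a path; no image data is read (assumed JPEGImages layout)
    return os.path.join(data_path, 'JPEGImages', index + ext)


if __name__ == '__main__':
    class_to_ind = build_class_to_ind(CLASSES)
    print(class_to_ind['__background__'], class_to_ind['bird'])  # 0 3

    with tempfile.TemporaryDirectory() as root:
        main_dir = os.path.join(root, 'ImageSets', 'Main')
        os.makedirs(main_dir)
        with open(os.path.join(main_dir, 'trainval.txt'), 'w') as f:
            f.write('000005\n000007\n')
        print(load_image_set_index(root, 'trainval'))  # ['000005', '000007']
```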

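The resulting imdb = pascal_voc('trainval', '2007') holds the following:

[1] - _class_to_ind: dict mapping each class name (key) to its label (value), starting from 0; the background class maps to 0, i.e. _class_to_ind['__background__'] == 0

[2] - _classes: the class names, 20 object classes + 1 background = 21 classes

[3] - _data_path: path to the dataset

[4] - _image_ext: '.jpg', the image file extension

[5] - _image_index: list of image indices

[6] - _image_set: 'trainval'

[7] - _name: dataset name, voc_2007_trainval

[8] - _num_classes: 0 (set in imdb.__init__ and not updated by pascal_voc.__init__; use num_classes instead)

[9] - _obj_proposer: selective_search

[10] - _roidb: None

[11] - classes: same as _classes

[12] - image_index: same as _image_index

[13] - name: dataset name, same as _name

[14] - num_classes: number of classes, 21

[15] - num_images: number of images

[16] - config: dict of PASCAL-specific configuration options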
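The list above separates underscored fields (_classes, _image_index, _num_classes) from public names (classes, image_index, num_classes, num_images). A likely explanation, assumed here rather than shown in this post, is that imdb exposes the public names as read-only properties computed from the underscored fields. The hypothetical MiniImdb and MiniVoc classes below model that pattern and show how _num_classes can stay 0 while num_classes reports 21.

```python
class MiniImdb(object):
    """Cut-down imdb: public names are read-only views over underscored fields."""

    def __init__(self, name):
        self._name = name
        self._num_classes = 0  # set once here and never updated
        self._classes = []
        self._image_index = []

    @property
    def name(self):
        return self._name

    @property
    def classes(self):
        return self._classes

    @property
    def num_classes(self):
        # Derived from _classes, so it is correct even though _num_classes stays 0
        return len(self._classes)

    @property
    def image_index(self):
        return self._image_index

    @property
    def num_images(self):
        return len(self.image_index)


class MiniVoc(MiniImdb):
    def __init__(self, image_set, year):
        MiniImdb.__init__(self, 'voc_' + year + '_' + image_set)
        # placeholder names: background + 20 object classes
        self._classes = ('__background__',) + tuple('class_{}'.format(i) for i in range(1, 21))
        self._image_index = ['000005', '000007']  # hypothetical ids


if __name__ == '__main__':
    voc = MiniVoc('trainval', '2007')
    print(voc.name, voc._num_classes, voc.num_classes, voc.num_images)
    # voc_2007_trainval 0 21 2
```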

After the imdb is loaded, the next step is:

imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
print 'Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD)
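The call above selects a roidb handler by name. The sketch below models that mechanism under stated assumptions: ProposalSource is a hypothetical class, getattr(self, method + '_roidb') stands in for the eval call inside set_proposal_method, and gt_roidb writes a dummy roidb to a pickle cache on first use, mirroring the cache-file logic of pascal_voc.gt_roidb further down. Python 3's pickle takes the place of Python 2's cPickle.

```python
import os
import pickle
import tempfile


class ProposalSource(object):
    """Hypothetical stand-in for imdb/pascal_voc, reduced to handler selection and caching."""

    def __init__(self, name, cache_path):
        self.name = name
        self.cache_path = cache_path
        self.roidb_handler = None
        self.load_count = 0  # how often the "annotations" were actually built

    def set_proposal_method(self, method):
        # getattr returns the bound method, like eval('self.' + method + '_roidb')
        self.roidb_handler = getattr(self, method + '_roidb')

    def gt_roidb(self):
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                return pickle.load(fid)
        self.load_count += 1
        roidb = [{'image': '000005', 'boxes': [[0, 0, 10, 10]]}]  # dummy entry
        with open(cache_file, 'wb') as fid:
            pickle.dump(roidb, fid, pickle.HIGHEST_PROTOCOL)
        return roidb


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as cache:
        src = ProposalSource('voc_2007_trainval', cache)
        src.set_proposal_method('gt')
        first = src.roidb_handler()   # builds the roidb and writes the .pkl
        second = src.roidb_handler()  # served from the cache file
        print(first == second, src.load_count)  # True 1
```

getattr yields the same bound method as the eval expression without evaluating an arbitrary string; a name with no matching *_roidb method raises AttributeError in both cases.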

In config.py, cfg.TRAIN.PROPOSAL_METHOD defaults to selective_search.

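In experiments/cfgs/faster_rcnn_end2end.yml, cfg.TRAIN.PROPOSAL_METHOD is set to gt. When this file is passed to train_net.py with --cfg, it overrides the default, so end-to-end training uses the ground-truth roidb.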
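How the YAML value replaces the config.py default can be sketched with plain dictionaries. This is a simplification under assumptions: the real code keeps cfg as an EasyDict, reads the .yml file with YAML, and its merge helper performs extra checks. Here DEFAULTS, END2END_YML, and merge_into are hypothetical names, and only the key discussed above is included.

```python
import copy

# Default from config.py (only the key discussed here)
DEFAULTS = {'TRAIN': {'PROPOSAL_METHOD': 'selective_search'}}

# What faster_rcnn_end2end.yml contributes, written as a dict instead of YAML
END2END_YML = {'TRAIN': {'PROPOSAL_METHOD': 'gt'}}


def merge_into(override, base):
    """Recursively copy values from override into base; unknown keys are rejected."""
    for key, value in override.items():
        if key not in base:
            raise KeyError('{} is not a valid config key'.format(key))
        if isinstance(value, dict):
            merge_into(value, base[key])
        else:
            base[key] = value


if __name__ == '__main__':
    cfg = copy.deepcopy(DEFAULTS)
    print(cfg['TRAIN']['PROPOSAL_METHOD'])  # selective_search
    merge_into(END2END_YML, cfg)
    print(cfg['TRAIN']['PROPOSAL_METHOD'])  # gt
```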

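The set_proposal_method function, together with the gt_roidb handler it selects and the annotation loader that gt_roidb calls: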

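    # imdb.set_proposal_method (datasets/imdb.py)
    def set_proposal_method(self, method):
        # eval turns the string into an expression, e.g. 'gt' -> self.gt_roidb, a method defined in pascal_voc
        method = eval('self.' + method + '_roidb')
        self.roidb_handler = method

    # pascal_voc methods (datasets/pascal_voc.py)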
    def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = cPickle.load(fid)
            print '{} gt roidb loaded from {}'.format(self.name, cache_file)
            return roidb

        gt_roidb = [self._load_pascal_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            cPickle.dump(gt_roidb, fid, cPickle.HIGHEST_PROTOCOL)
        print 'wrote gt roidb to {}'.format(cache_file)

        return gt_roidb


    def _load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        if not self.config['use_diff']:
            # Exclude the samples labeled as difficult
            non_diff_objs = [
                obj for obj in objs if int(obj.find(