Faster R-CNN 训练过程源码理解
训练脚本为 ./tools/train_net.py,下面从其主函数开始分析。
数据读取层 RoIDataLayer
首先,
# Load the training dataset(s) named by args.imdb_name (default in this
# walkthrough: voc_2007_trainval, the dataset name) and their roidb.
imdb, roidb = combined_roidb(args.imdb_name)
print('{:d} roidb entries'.format(len(roidb)))
然后,函数 combined_roidb:
def combined_roidb(imdb_names):
    """Load the training roidb for one or more datasets.

    Args:
        imdb_names: a dataset name, or several names joined with '+'
            (the walkthrough's default is 'voc_2007_trainval').

    Returns:
        A tuple (imdb, roidb). With a single name, imdb is a fresh
        get_imdb(imdb_names). With several names, the per-dataset roidbs
        are concatenated in order into the first one (which is extended in
        place), and imdb is a generic datasets.imdb.imdb named after the
        full '+'-joined string.
        NOTE(review): that generic imdb is presumably the base imdb class,
        which starts with an empty class list; verify before relying on
        its num_classes.
    """
    def _load_one(name):
        # get_imdb (from factory.py) builds the dataset object. For the
        # default name this is pascal_voc('trainval', '2007'), which records
        # paths and the image index but does not read any image data.
        dataset = get_imdb(name)
        print('Loaded dataset `{:s}` for training'.format(dataset.name))
        dataset.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        return get_training_roidb(dataset)

    roidbs = [_load_one(name) for name in imdb_names.split('+')]
    merged = roidbs[0]
    if len(roidbs) == 1:
        return get_imdb(imdb_names), merged
    for extra in roidbs[1:]:
        merged.extend(extra)
    return datasets.imdb.imdb(imdb_names), merged
pascal_voc 数据集对应的类(继承自 imdb):
class pascal_voc(imdb):  # subclass of the generic imdb base class
    """PASCAL VOC dataset for one image set (e.g. 'trainval') of one year."""

    def __init__(self, image_set, year, devkit_path=None):
        """Record dataset paths, class names and the image index.

        No image pixels are read here. The image index is obtained from
        self._load_image_set_index().

        Args:
            image_set: image set name, e.g. 'trainval'.
            year: dataset year as a string, e.g. '2007'.
            devkit_path: root directory of the VOCdevkit. When None, falls
                back to '/data/VOCdevkit'.

        Raises:
            AssertionError: if the devkit directory or its VOC<year>
                sub-directory does not exist (these checks are asserts and
                are skipped under `python -O`).
        """
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        # Honour an explicit devkit_path. Previously the argument was
        # silently ignored in favour of the hard-coded path, which is now
        # only the default.
        self._devkit_path = ('/data/VOCdevkit' if devkit_path is None
                             else devkit_path)
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__',  # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        # Map class name -> integer label, following the order of classes.
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default roidb handler.
        # NOTE(review): set_proposal_method assigns self.roidb_handler,
        # presumably a property over this attribute; confirm.
        self._roidb_handler = self.selective_search_roidb
        # NOTE(review): random per-instance salt, presumably appended to
        # result file names when config['use_salt'] is True; confirm.
        self._salt = str(uuid.uuid4())
        # NOTE(review): 'comp4' looks like a PASCAL VOC competition id used
        # in result file names; confirm.
        self._comp_id = 'comp4'

        # PASCAL specific config options
        self.config = {
            'cleanup': True,
            'use_salt': True,
            # When False, objects marked "difficult" in the annotation XML
            # are excluded from the ground truth.
            'use_diff': False,
            'matlab_eval': False,
            'rpn_file': None,
            'min_size': 2}

        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)
#
class imdb(object):
    """Image database.

    Generic base class for an image dataset. Concrete datasets (such as
    pascal_voc) call this constructor and then fill in the fields below.
    """

    def __init__(self, name):
        # Dataset identity.
        self._name = name

        # Class names and image index start out empty.
        self._classes = []
        self._num_classes = 0
        self._image_index = []

        # Proposal source and roidb state.
        self._obj_proposer = 'selective_search'
        self._roidb = None
        self._roidb_handler = self.default_roidb

        # Dataset-specific config options are stored in this dict.
        self.config = {}
得到的 imdb = pascal_voc('trainval', '2007') 记录的内容如下:
[1] - _class_to_ind,dict 类型,key 是类别名,value 是 label 值(从 0 开始),其中背景类 '__background__' 对应 label 0
[2] - _classes,object 类别名,共 20(object classes) + 1(background) = 21 classes.
[3] - _data_path,数据集路径
[4] - _image_ext,图片文件扩展名 '.jpg'
[5] - _image_index,图片索引列表
[6] - _image_set,'trainval'
[7] - _name,数据集名称 voc_2007_trainval
[8] - _num_classes,0(基类 imdb.__init__ 将其初始化为 0,pascal_voc.__init__ 没有修改它;实际类别数请看下面的 num_classes)
[9] - _obj_proposer,selective_search
[10] - _roidb,None
[11] - classes,与_classes 相同
[12] - image_index,与_image_index 相同
[13] - name,数据集名称,与 _name 相同
[14] - num_classes,类别数,21
[15] - num_images,图片数
[16] - config,dict 类型,PASCAL 数据集指定的配置
读取 imdb 后,接着设置 proposal 方法:
# Choose how the roidb will be built, according to cfg.TRAIN.PROPOSAL_METHOD.
imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
config.py 中 cfg.TRAIN.PROPOSAL_METHOD 的默认值为 selective_search;
experiments/cfgs/faster_rcnn_end2end.yml 中将 cfg.TRAIN.PROPOSAL_METHOD 设为 gt,因此使用该配置进行 end2end 训练时,实际使用的是 gt_roidb。
set_proposal_method 函数定义如下:
def set_proposal_method(self, method):
    """Select the roidb handler used to build this dataset's roidb.

    Args:
        method: proposal source name, e.g. 'gt' or 'selective_search'.
            The handler becomes the bound method self.<method>_roidb
            (for example 'gt' selects self.gt_roidb).

    Raises:
        AttributeError: if the dataset has no <method>_roidb method.
    """
    # getattr performs the same attribute lookup as the previous
    # eval('self.' + method + '_roidb') for valid names. Unlike eval, it
    # cannot execute arbitrary expressions smuggled in via the config value.
    handler = getattr(self, method + '_roidb')
    self.roidb_handler = handler
def gt_roidb(self):
    """Return the ground-truth roidb: one entry per image in self.image_index.

    Entries are produced by self._load_pascal_annotation(). The list is
    cached as a pickle at <cache_path>/<name>_gt_roidb.pkl. If that file
    exists, it is loaded and returned as-is. It is not checked for
    staleness, so delete it after changing annotations or
    config['use_diff'].

    NOTE: the cache file is unpickled, so it must come from a trusted
    source.
    """
    cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')

    if not os.path.exists(cache_file):
        # Cache miss: parse every annotation, then persist the result.
        entries = [self._load_pascal_annotation(idx)
                   for idx in self.image_index]
        with open(cache_file, 'wb') as out:
            cPickle.dump(entries, out, cPickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))
        return entries

    # Cache hit: reuse the previously pickled roidb.
    with open(cache_file, 'rb') as src:
        cached = cPickle.load(src)
    print('{} gt roidb loaded from {}'.format(self.name, cache_file))
    return cached
def _load_pascal_annotation(self, index):
"""
Load image and bounding boxes info from XML file in the PASCAL VOC
format.
"""
filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
tree = ET.parse(filename)
objs = tree.findall('object')
if not self.config['use_diff']:
# Exclude the samples labeled as difficult
non_diff_objs = [
obj for obj in objs if int(obj.find(