通过入口文件trainval_net.py找到数据集相关部分,发现程序首先根据输入参数args.dataset确定args.imdb_name、args.imdbval_name、args.set_cfgs这三个参数,之后再对这些参数进行处理。处理过程如下:
# Build the dataset objects from the configured imdb name.
imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdb_name)

# The sampler is sized by the number of roidb entries.
train_size = len(roidb)
sampler_batch = sampler(train_size, args.batch_size)

# Wrap the roidb in a dataset and hand it to the PyTorch DataLoader.
dataset = roibatchLoader(roidb, ratio_list, ratio_index, args.batch_size,
                         imdb.num_classes, training=True)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=args.batch_size,
                                         sampler=sampler_batch,
                                         num_workers=args.num_workers)
下面分别对数据处理部分详细剖析:
1.imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdb_name)
函数combined_roidb():返回数据集、数据集中的ROI、ROI的ratio_list、ROI的ratio_index。那么各个功能到底是如何实现的呢?进入combined_roidb()中一探究竟,从后往前推。代码如下:
def combined_roidb(imdb_names, training=True):
    """Combine the roidbs of one or more datasets into a single roidb.

    Args:
        imdb_names: A dataset name, or several names joined with ``+``.
        training: When true, the merged roidb is passed through
            ``filter_roidb`` before the ratio ranking is computed.

    Returns:
        A tuple ``(imdb, roidb, ratio_list, ratio_index)``; the last two
        items are whatever ``rank_roidb_ratio(roidb)`` returns.
    """

    def get_training_roidb(imdb):
        """Returns a roidb (Region of Interest database) for use in training."""
        if cfg.TRAIN.USE_FLIPPED:
            print('Appending horizontally-flipped training examples...')
            imdb.append_flipped_images()
            print('done')

        print('Preparing training data...')
        prepare_roidb(imdb)
        print('done')

        return imdb.roidb

    def get_roidb(imdb_name):
        """Load the named imdb and return its training roidb."""
        # get_imdb() returns the image database registered under this name.
        imdb = get_imdb(imdb_name)
        print('Loaded dataset `{:s}` for training'.format(imdb.name))
        imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
        print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
        return get_training_roidb(imdb)

    # Several datasets may be requested at once: their names are joined
    # with '+' by the caller and split apart again here.
    names = imdb_names.split('+')
    roidbs = [get_roidb(name) for name in names]

    # The first roidb is the base; any further ones are appended to it
    # in place.
    roidb = roidbs[0]

    if len(roidbs) > 1:
        for extra in roidbs[1:]:
            roidb.extend(extra)
        # The combined imdb takes its class list from the second dataset.
        tmp = get_imdb(names[1])
        imdb = datasets.imdb.imdb(imdb_names, tmp.classes)
    else:
        imdb = get_imdb(imdb_names)

    if training:
        roidb = filter_roidb(roidb)

    ratio_list, ratio_index = rank_roidb_ratio(roidb)

    return imdb, roidb, ratio_list, ratio_index
(1)imdb:imdb = get_imdb(imdb_names)
(2)roidb:
# Build one roidb per dataset name; multiple names are separated by '+'.
roidbs = [get_roidb(s) for s in imdb_names.split('+')]
# Only the first roidb is taken here (the article states that a single
# dataset is used in practice -- confirm against the caller).
roidb = roidbs[0]
1)实际上roidb由get_roidb()得到:
def get_roidb(imdb_name):
    """Load the named imdb and return its training roidb.

    The proposal method of the loaded imdb is set from
    ``cfg.TRAIN.PROPOSAL_METHOD`` before the roidb is built.
    """
    # get_imdb() returns the image database registered under this name.
    imdb = get_imdb(imdb_name)
    print('Loaded dataset `{:s}` for training'.format(imdb.name))
    imdb.set_proposal_method(cfg.TRAIN.PROPOSAL_METHOD)
    print('Set proposal method: {:s}'.format(cfg.TRAIN.PROPOSAL_METHOD))
    # Key step: the roidb is derived from the imdb by get_training_roidb().
    roidb = get_training_roidb(imdb)
    return roidb
1.1)get_training_roidb():
def get_training_roidb(imdb):
    """Returns a roidb (Region of Interest database) for use in training.

    When ``cfg.TRAIN.USE_FLIPPED`` is set, ``imdb.append_flipped_images()``
    is called first. ``prepare_roidb(imdb)`` is then run and ``imdb.roidb``
    is returned.
    """
    if cfg.TRAIN.USE_FLIPPED:
        print('Appending horizontally-flipped training examples...')
        imdb.append_flipped_images()
        print('done')

    print('Preparing training data...')
    prepare_roidb(imdb)
    print('done')

    return imdb.roidb
1.1.1) prepare_roidb(imdb):
def prepare_roidb(imdb):
    """Enrich the imdb's roidb by adding some derived quantities that
    are useful for training. This function precomputes the maximum
    overlap, taken over ground-truth boxes, between each ROI and
    each ground-truth box. The class with maximum overlap is also
    recorded.

    Each ``imdb.roidb[i]`` entry is updated in place with the keys
    ``img_id``, ``image``, ``max_classes`` and ``max_overlaps``. Unless
    the dataset name starts with ``coco``, ``width`` and ``height`` are
    also filled in from the image file. Nothing is returned.

    Raises:
        AssertionError: if a row with zero max overlap is not class 0, or
            a row with positive max overlap is class 0.
    """
    roidb = imdb.roidb

    # Datasets whose name starts with 'coco' skip the image-size lookup
    # (presumably their entries already carry width/height -- confirm).
    needs_sizes = not imdb.name.startswith('coco')
    if needs_sizes:
        sizes = []
        for i in range(imdb.num_images):
            # Close each image as soon as its size is read so that file
            # handles do not pile up on large datasets.
            with PIL.Image.open(imdb.image_path_at(i)) as img:
                sizes.append(img.size)

    for i in range(len(imdb.image_index)):
        entry = roidb[i]
        entry['img_id'] = imdb.image_id_at(i)
        entry['image'] = imdb.image_path_at(i)
        if needs_sizes:
            entry['width'] = sizes[i][0]
            entry['height'] = sizes[i][1]

        # need gt_overlaps as a dense array for argmax
        gt_overlaps = entry['gt_overlaps'].toarray()
        # max overlap with gt over classes (columns)
        max_overlaps = gt_overlaps.max(axis=1)
        # gt class that had the max overlap
        max_classes = gt_overlaps.argmax(axis=1)
        entry['max_classes'] = max_classes
        entry['max_overlaps'] = max_overlaps

        # sanity checks
        # max overlap of 0 => class should be zero (background)
        zero_inds = np.where(max_overlaps == 0)[0]
        assert all(max_classes[zero_inds] == 0)
        # max overlap > 0 => class should not be zero (must be a fg class)
        nonzero_inds = np.where(max_overlaps > 0)[0]
        assert all(max_classes[nonzero_inds] != 0)