数据集介绍
一般的目标检测数据集由两部分组成,图片images
和标签annotations
。以少样本目标检测数据集FSOD为例。
图片部分就不多介绍了,重点来看一下标记部分annotation,对于图片的标记数据一般用json
格式保存。
上面的图片是FSOD测试集的标记数据,以字典的形式保存,Keys分别为:
- images: Values是一个列表,长度即测试集图片的数量。列表中的每个元素对应一个图片的数据, 例如
id, file_name, width, height
- type: Values是“instances”
- annotations: Values是一个列表,长度即测试集中标注实例(目标)的数量——每个被标注的物体对应一个元素,因此通常多于图片数量。列表中的每个元素对应一个目标的标记数据,例如
ignore, image_id, segmentation, bbox, area, category_id, is_crowed, id
- categories: Values是一个列表,长度即数据集类别的数量。列表中的每个元素对应一个类别的数据,例如
supercategory, id, name
代码部分
Class: JsonDataset
这个类用来表示一张图片的标记数据
初始化函数:构造这个类只需要数据集的名称,如"fsod"。
def __init__(self, name):
    """Hold metadata for one COCO-format dataset (e.g. "fsod").

    Args:
        name: dataset key registered in the DATASETS dict, which maps
            each dataset name to its image directory and annotation file.

    Raises:
        AssertionError: if the name is unknown or its paths do not exist.
    """
    # DATASETS is a pre-built registry mapping dataset names to paths.
    assert name in DATASETS.keys(), \
        'Unknown dataset name: {}'.format(name)
    assert os.path.exists(DATASETS[name][IM_DIR]), \
        "Image directory '{}' not found".format(DATASETS[name][IM_DIR])
    assert os.path.exists(DATASETS[name][ANN_FN]), \
        "Annotation file '{}' not found".format(DATASETS[name][ANN_FN])
    logger.debug('Creating: {}'.format(name))
    self.name = name
    self.image_directory = DATASETS[name][IM_DIR]
    # Optional prefix prepended to every image file name (default: none).
    self.image_prefix = (
        '' if IM_PREFIX not in DATASETS[name] else DATASETS[name][IM_PREFIX]
    )
    # Parse the annotation file with the COCO API.
    self.COCO = COCO(DATASETS[name][ANN_FN])
    self.debug_timer = Timer()
    # Set up dataset classes.
    # All json category ids, then the corresponding category names.
    category_ids = self.COCO.getCatIds()
    categories = [c['name'] for c in self.COCO.loadCats(category_ids)]
    # Bidirectional mapping between category name and json id.
    self.category_to_id_map = dict(zip(categories, category_ids))
    self.id_to_category_map = dict(zip(category_ids, categories))
    self.classes = ['__background__'] + categories
    self.num_classes = len(self.classes)
    # Map (possibly sparse) json category ids to contiguous ids 1..K;
    # index 0 is reserved for the background class.
    self.json_category_id_to_contiguous_id = {
        v: i + 1
        for i, v in enumerate(category_ids)
    }
    # Inverse mapping: contiguous id -> original json id.
    self.contiguous_category_id_to_json_id = {
        v: k
        for k, v in self.json_category_id_to_contiguous_id.items()
    }
def get_roidb(): 根据图片的标记数据构造数据结构roidb
def get_roidb(self, gt=False, crowd_filter_thresh=0, proposal_file=None):
    """Build the roidb: a list with one dict of metadata per image.

    Args:
        gt: if True, fill each entry with its ground-truth annotations.
        crowd_filter_thresh: crowd-overlap filter threshold; must be 0
            when gt is False.
        proposal_file: accepted for compatibility with callers that pass
            it (e.g. combined_roidb_for_training); unused in this
            snippet — the proposal-loading branch is not shown here.

    Returns:
        The list of per-image roidb entries.
    """
    assert gt is True or crowd_filter_thresh == 0, \
        'Crowd filter threshold must be 0 if ground-truth annotations ' \
        'are not included.'
    # Fetch all image ids and sort them for a deterministic order.
    image_ids = self.COCO.getImgIds()
    image_ids.sort()
    # loadImgs returns per-image dicts shaped like:
    # {'id': 2019000089872, 'file_name': 'part_1/.../n03715892_10790.jpg',
    #  'width': 375, 'height': 500}
    if cfg.DEBUG:
        # Debug mode: keep only the first 100 images to speed iteration.
        roidb = copy.deepcopy(self.COCO.loadImgs(image_ids))[:100]
    else:
        roidb = copy.deepcopy(self.COCO.loadImgs(image_ids))
    # Initialize the remaining roidb fields with empty placeholders.
    for entry in roidb:
        self._prep_roidb_entry(entry)
    if gt:
        # Include ground-truth object annotations. A pickle cache is
        # written the first time the annotation file is processed.
        cache_filepath = os.path.join(
            self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_filepath) and not cfg.DEBUG:
            self.debug_timer.tic()
            self._add_gt_from_cache(roidb, cache_filepath)
            logger.debug(
                '_add_gt_from_cache took {:.3f}s'.
                format(self.debug_timer.toc(average=False))
            )
        else:
            self.debug_timer.tic()
            for entry in roidb:
                # Fill the entry from its annotation records.
                self._add_gt_annotations(entry)
            logger.debug(
                '_add_gt_annotations took {:.3f}s'.
                format(self.debug_timer.toc(average=False))
            )
            if not cfg.DEBUG:
                with open(cache_filepath, 'wb') as fp:
                    pickle.dump(roidb, fp, pickle.HIGHEST_PROTOCOL)
                logger.info('Cache ground truth roidb to %s', cache_filepath)
    # Compute max_overlaps (per-box max class score) and max_classes
    # (per-box argmax class) for every entry.
    _add_class_assignments(roidb)
    return roidb
从标记数据中构造roidb数据类型之后,需要对其进行一些简单的修改以便能够更好地适应训练,如数据增强、添加回归目标,删除一些不可用的数据等。
def combined_roidb_for_training(dataset_names, proposal_files):
    """Load, split and combine roidbs from one or more datasets for training.

    For the first dataset, each image entry is split into one entry per
    ground-truth class present in the image (few-shot episode style);
    entries from any additional datasets are appended unchanged. The
    combined roidb is then filtered, optionally ratio-sorted/cropped,
    and augmented with bounding-box regression targets.

    Args:
        dataset_names: one dataset name or a sequence of names.
        proposal_files: matching proposal file(s); may be empty.

    Returns:
        (roidb, ratio_list, ratio_index, cls_list, id_list); the last
        four are None unless aspect grouping/cropping is enabled.
    """
    def get_roidb(dataset_name, proposal_file):
        ds = JsonDataset(dataset_name)
        roidb = ds.get_roidb(
            gt=True,
            proposal_file=proposal_file,
            crowd_filter_thresh=cfg.TRAIN.CROWD_FILTER_THRESH
        )
        # Data augmentation: horizontally flipped copies of every image.
        if cfg.TRAIN.USE_FLIPPED:
            logger.info('Appending horizontally-flipped training examples...')
            extend_with_flipped_entries(roidb, ds)
        logger.info('Loaded dataset: {:s}'.format(ds.name))
        return roidb

    # Normalize string arguments into 1-tuples.
    if isinstance(dataset_names, six.string_types):
        dataset_names = (dataset_names, )
    if isinstance(proposal_files, six.string_types):
        proposal_files = (proposal_files, )
    if len(proposal_files) == 0:
        proposal_files = (None, ) * len(dataset_names)
    assert len(dataset_names) == len(proposal_files)
    roidbs = [get_roidb(*args) for args in zip(dataset_names, proposal_files)]
    # roidb of the first dataset: split it per class.
    original_roidb = roidbs[0]
    roidb = []
    for item in original_roidb:
        gt_classes = list(set(item['gt_classes']))
        all_cls = np.array(item['gt_classes'])
        # One new entry per distinct class in the image, keeping only
        # the boxes/annotations belonging to that class.
        for cls in gt_classes:
            item_new = item.copy()
            target_idx = np.where(all_cls == cls)[0]
            item_new['target_cls'] = int(cls)
            item_new['boxes'] = item_new['boxes'][target_idx]
            item_new['max_classes'] = item_new['max_classes'][target_idx]
            item_new['gt_classes'] = item_new['gt_classes'][target_idx]
            item_new['is_crowd'] = item_new['is_crowd'][target_idx]
            # Select the segmentations at target_idx (segms is a plain
            # list, so index it element-wise). The original code took a
            # prefix slice [:len(target_idx)], which returned the wrong
            # segmentations whenever the class's boxes were not first.
            item_new['segms'] = [item_new['segms'][i] for i in target_idx]
            item_new['seg_areas'] = item_new['seg_areas'][target_idx]
            item_new['max_overlaps'] = item_new['max_overlaps'][target_idx]
            item_new['box_to_gt_ind_map'] = np.arange(
                item_new['gt_classes'].shape[0])
            item_new['gt_overlaps'] = item_new['gt_overlaps'][target_idx]
            roidb.append(item_new)
    # Entries from any remaining datasets are appended unsplit.
    for r in roidbs[1:]:
        roidb.extend(r)
    # Drop entries without usable rois.
    roidb = filter_for_training(roidb)
    if cfg.TRAIN.ASPECT_GROUPING or cfg.TRAIN.ASPECT_CROPPING:
        logger.info('Computing image aspect ratios and ordering the ratios...')
        # Sort images by aspect ratio and crop overly elongated ones.
        ratio_list, ratio_index, cls_list, id_list = rank_for_training(roidb)
        logger.info('done')
    else:
        ratio_list, ratio_index, cls_list, id_list = None, None, None, None
    logger.info('Computing bounding-box regression targets...')
    add_bbox_regression_targets(roidb)
    logger.info('done')
    _compute_and_log_stats(roidb)
    # Report the final size via the logger (was a bare print).
    logger.info('roidb entries: %d', len(roidb))
    return roidb, ratio_list, ratio_index, cls_list, id_list
roidb数据构造完成后,通过DataLoader
生成参与训练的数据形式。
# Wrap the finished roidb in a Dataset so PyTorch can index it.
dataset = RoiDataLoader(
    roidb,
    cfg.MODEL.NUM_CLASSES,
    info_list,
    ratio_list,
    training=True)
# DataLoader drives batching via the custom batch sampler and collate fn.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batchSampler,
    # Single-process loading; the configurable thread count is disabled.
    num_workers=0,
    #num_workers=cfg.DATA_LOADER.NUM_THREADS,
    collate_fn=collate_minibatch)
# Explicit iterator lets the training loop pull minibatches on demand.
dataiterator = iter(dataloader)
参考材料:https://github.com/fanq15/FSOD-code
https://blog.csdn.net/qq_34809033/article/details/83215698