读取矢量数据的点坐标集_PyTorch读取目标检测数据集

数据集介绍

一般的目标检测数据集由两部分组成,图片images和标签annotations。由少样本目标检测数据集FSOD为例。

38458bc26560f56a888fed94d34b6090.png
FSOD数据集组成

图片部分就不多介绍了,重点来看一下标记部分annotation,对于图片的标记数据一般用json格式保存。

88bccacf457611dd5f82317cb6e1bf71.png
test.json文件数据

上面的图片是FSOD测试集的标记数据,以字典的形式保存,Keys分别为:

  1. images: Values是一个列表,长度即测试集图片的数量。列表中的每个元素对应一个图片的数据, 例如id, file_name, width, height
  2. type: Values是“instances”
  3. annotations: Values是一个列表,长度即测试集图片的数量。列表中的每个元素对应一个图片的标记数据,例如 ignore, image_id, segmentation, bbox, area, category_id, is_crowed, id
  4. categories: Values是一个列表,长度即数据集类别的数量。列表中的每个元素对应一个类别的数据,例如 supercategory, id, name

代码部分

Class: JsonDataset 这个类用来表示一张图片的标记数据

初始化函数:构造这个类只需要数据集的名称,如"fsod"。

    def __init__(self, name):
        # DATASETS是我们预先写好的字典,保存各种数据集的相关路径
        assert name in DATASETS.keys(), 
            'Unknown dataset name: {}'.format(name)
        assert os.path.exists(DATASETS[name][IM_DIR]), 
            'Image directory '{}' not found'.format(DATASETS[name][IM_DIR])
        assert os.path.exists(DATASETS[name][ANN_FN]), 
            'Annotation file '{}' not found'.format(DATASETS[name][ANN_FN])
        logger.debug('Creating: {}'.format(name))
        self.name = name
        self.image_directory = DATASETS[name][IM_DIR]
        self.image_prefix = (
            '' if IM_PREFIX not in DATASETS[name] else DATASETS[name][IM_PREFIX]
        )
        # 根据标记文件路径构造COCO类
        self.COCO = COCO(DATASETS[name][ANN_FN])
        self.debug_timer = Timer()
        # Set up dataset classes
        # 获取所有的类别id
        category_ids = self.COCO.getCatIds()
        # 根据类别id获取所有类别的名称
        categories = [c['name'] for c in self.COCO.loadCats(category_ids)]
        # 类别名称与id的相互映射
        self.category_to_id_map = dict(zip(categories, category_ids))
        self.id_to_category_map = dict(zip(category_ids, categories))
        self.classes = ['__background__'] + categories
        self.num_classes = len(self.classes)
        # 类别id与其索引的映射
        self.json_category_id_to_contiguous_id = {
            v: i + 1
            for i, v in enumerate(self.COCO.getCatIds())
        }
        # 类别索引与其id的映射
        self.contiguous_category_id_to_json_id = {
            v: k
            for k, v in self.json_category_id_to_contiguous_id.items()
        }

def get_roidb(): 根据图片的标记数据构造数据结构roidb

dc22c1c1a128a82bc709fec830b2fcec.png
roidb的结构
    def get_roidb(self, gt=False, crowd_filter_thresh=0,):
        assert gt is True or crowd_filter_thresh == 0, 
            'Crowd filter threshold must be 0 if ground-truth annotations ' 
            'are not included.'
        # 获取所有图片id并排序
        image_ids = self.COCO.getImgIds()
        image_ids.sort()
        # 根据图片id返回图片信息,格式如下所示:
        # {'id': 2019000089872, 'file_name': 'part_1/n03715892/n03715892_10790.jpg', 'width': 375, 'height': 500}
        if cfg.DEBUG:
            roidb = copy.deepcopy(self.COCO.loadImgs(image_ids))[:100]
        else:
            roidb = copy.deepcopy(self.COCO.loadImgs(image_ids))
        # 填充roidb还缺少的元素构成完整的roidb结构,暂用空值代替
        for entry in roidb:
            self._prep_roidb_entry(entry)
        if gt:
            # Include ground-truth object annotations
            # 如果已经读取过标注文件会有缓存文件
            cache_filepath = os.path.join(self.cache_path, self.name+'_gt_roidb.pkl')
            if os.path.exists(cache_filepath) and not cfg.DEBUG:
                self.debug_timer.tic()
                self._add_gt_from_cache(roidb, cache_filepath)
                logger.debug(
                    '_add_gt_from_cache took {:.3f}s'.
                    format(self.debug_timer.toc(average=False))
                )
            else:
                self.debug_timer.tic()
                for entry in roidb:
                    # 根据标记数据补充roidb中相应的值
                    self._add_gt_annotations(entry)
                logger.debug(
                    '_add_gt_annotations took {:.3f}s'.
                    format(self.debug_timer.toc(average=False))
                )
                if not cfg.DEBUG:
                    with open(cache_filepath, 'wb') as fp:
                        pickle.dump(roidb, fp, pickle.HIGHEST_PROTOCOL)
                    logger.info('Cache ground truth roidb to %s', cache_filepath)
        # 计算max_overlaps: 某张图片每个box所有类别的得分最大值
        #     max_classes:  某张图片每个box得分最高对应的类
        _add_class_assignments(roidb)  

        return roidb

从标记数据中构造roidb数据类型之后,需要对其进行一些简单的修改以便能够更好地适应训练,如数据增强、添加回归目标,删除一些不可用的数据等。

def combined_roidb_for_training(dataset_names, proposal_files):
    def get_roidb(dataset_name, proposal_file):
        ds = JsonDataset(dataset_name)
        roidb = ds.get_roidb(
            gt=True,
            proposal_file=proposal_file,
            crowd_filter_thresh=cfg.TRAIN.CROWD_FILTER_THRESH
        )
        # 数据增强:水平翻转图片
        if cfg.TRAIN.USE_FLIPPED:
            logger.info('Appending horizontally-flipped training examples...')
            extend_with_flipped_entries(roidb, ds)
        logger.info('Loaded dataset: {:s}'.format(ds.name))
        return roidb

    if isinstance(dataset_names, six.string_types):
        dataset_names = (dataset_names, )
    if isinstance(proposal_files, six.string_types):
        proposal_files = (proposal_files, )
    if len(proposal_files) == 0:
        proposal_files = (None, ) * len(dataset_names)
    assert len(dataset_names) == len(proposal_files)
    roidbs = [get_roidb(*args) for args in zip(dataset_names, proposal_files)]
    # 第一个数据集对应的roidb
    original_roidb = roidbs[0]
   
    # new dataset split according to class 
    roidb = []
    for item in original_roidb:
        gt_classes = list(set(item['gt_classes']))
        all_cls = np.array(item['gt_classes'])
        # 遍历一张图上的不同类别
        # 对每个类别对象建立一个entry
        for cls in gt_classes:
            item_new = item.copy()
            target_idx = np.where(all_cls == cls)[0] 
            #item_new['id'] = item_new['id'] * 1000 + int(cls)
            item_new['target_cls'] = int(cls)
            item_new['boxes'] = item_new['boxes'][target_idx]
            item_new['max_classes'] = item_new['max_classes'][target_idx]
            item_new['gt_classes'] = item_new['gt_classes'][target_idx]
            item_new['is_crowd'] = item_new['is_crowd'][target_idx]
            item_new['segms'] = item_new['segms'][:target_idx.shape[0]]
            item_new['seg_areas'] = item_new['seg_areas'][target_idx]
            item_new['max_overlaps'] = item_new['max_overlaps'][target_idx]
            item_new['box_to_gt_ind_map'] = np.array(range(item_new['gt_classes'].shape[0]))
            item_new['gt_overlaps'] = item_new['gt_overlaps'][target_idx]
            roidb.append(item_new)

    for r in roidbs[1:]:
        roidb.extend(r)
    # 移除没有可用roi的entry
    roidb = filter_for_training(roidb)

    if cfg.TRAIN.ASPECT_GROUPING or cfg.TRAIN.ASPECT_CROPPING:
        logger.info('Computing image aspect ratios and ordering the ratios...')
        # 将图片按横纵比进行排序,裁剪较大的图片
        ratio_list, ratio_index, cls_list, id_list = rank_for_training(roidb)
        logger.info('done')
    else:
        ratio_list, ratio_index, cls_list, id_list = None, None, None, None

    logger.info('Computing bounding-box regression targets...')
    add_bbox_regression_targets(roidb)
    logger.info('done')

    _compute_and_log_stats(roidb)

    print(len(roidb))
    return roidb, ratio_list, ratio_index, cls_list, id_list

roidb数据构造完成后,通过DataLoader 生成参与训练的数据形式。

dataset = RoiDataLoader(
        roidb,
        cfg.MODEL.NUM_CLASSES,
        info_list,
        ratio_list,
        training=True)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batchSampler,
        num_workers=0,
        #num_workers=cfg.DATA_LOADER.NUM_THREADS,
        collate_fn=collate_minibatch)
    dataiterator = iter(dataloader)

参考材料:https://github.com/fanq15/FSOD-code

https://blog.csdn.net/qq_34809033/article/details/83215698

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值