CenterNet代码之datasets

CenterNet(Objects as points)开源代码:https://github.com/xingyizhou/CenterNet

源码的dataset结构如下:

datasets  
  |
  |---dataset  # 解析各数据集(CenterNet共用了下面的数据集)
        |---coco.py     # Coco数据集
        |---coco_hp.py  # Coco human pose
        |---kitti.py    # kitti
        |---pascal.py   # PascalVOC
  |
  |---sample   # 针对不同的网络 提取所需数据
        |---ctdet.py       # CenterNet
        |---ddd.py         # 3D Detection
        |---exdet.py       # ExtremeNet
        |---multi_pose.py  # 
  |
  |---data_factory.py  # 整合dataset和sample,构建完整的pipeline

该结构这样设计的目的是拆分和精细化每一个步骤,看过论文的知道,CenterNet可以很好地在目标检测、3D检测、人体姿态等任务上迁移,所以作者这样设计datasets更方便我们随意结合,同时,如果我们想使用自己的数据集也会很方便。

下文是详细解释,只想直接用懒得细看请移步:https://blog.csdn.net/weixin_43509263/article/details/100799415

我的任务是目标检测,采用Ccco数据集,使用CenterNet,所以简化文件结构,保留如下:

datasets  
  |---dataset  
        |---coco.py
  |---sample
        |---ctdet.py   # CenterNet
  |---data_factory.py  

'''
实际上,一般构建Dataset我们都会继承torch.utils.data.Dataset, 
       一般都会重写__init__ 、__getitem__ 和 __len__ 三个函数,
这里,__init__、__len__在dataset实现,而 __getitem__在sample中
'''
  • dataset中coco.py解析coco数据集: 
"""
    对coco数据集进行解析
    
    def __init__(self, opt, split): 解析数据集中各属性
    def __len__(self): 返回样本数

    def run_eval(self, results, save_dir): eval接口
       \-- def save_results(self, results, save_dir): 保存结果
              \-- def convert_eval_format(self, all_bboxes): 将自己的结果 转换成coco要求的验证格式
"""
import pycocotools.coco as coco
import pycocotools.cocoeval as COCOeval
import numpy as np
import json
import os

import torch.utils.data as data

class COCO(data.Dataset): 

    num_classes = 80
    default_resolution = [512, 512]
    mean = np.array([0.40789654, 0.44719302, 0.47026115],
                    dtype=np.float32).reshape(1, 1, 3)
    std = np.array([0.28863828, 0.27408164, 0.27809835],
                   dtype=np.float32).reshape(1, 1, 3)

    def __init__(self, opt, split):
        '''
        :param opt: opt是传入的参数对象,在opt.py中
        :param split: train\val\test
        '''
        super(COCO, self).__init__()
        ## self.data_dir、img_dir、annot_dir
        self.data_dir = os.path.join(opt.data_dir, 'coco')
        self.img_dir = os.path.join(self.data_dir, '{}2017'.format(split))
        if split == 'test':
            self.annot_path = os.path.join(
                self.data_dir, 'annotations', 'image_info_test-dec2017.json')
        else:
            self.annot_path = os.path.join(
                    self.data_dir, 'annotations',
                    'instances_{}2017.json').format(split)

        ''' ???????????????????/ '''
        self.max_objs = 128
        # 类别名 加上__background__共81个
        self.class_name = [
            '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
            'bus', 'train
  • 3
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值