【代码阅读】【3d目标检测】pv-rcnn代码阅读(一)数据准备

最近在看pvrcnn的源码,加深一下对论文的理解。源码的理解多亏了这位大佬的注释,感谢一下!
首先我们看到train.py函数,在训练之前要对于数据进行一个预处理:

    train_set, train_loader, train_sampler = build_dataloader(
        dataset_cfg=cfg.DATA_CONFIG,
        class_names=cfg.CLASS_NAMES,
        batch_size=args.batch_size,
        dist=dist_train, workers=args.workers,
        logger=logger,
        training=True,
        merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
        total_epochs=args.epochs
    )

cfg和arg是前面定义的配置文件及变量。
我们进去查看build_dataloader函数做了什么。
build_dataloader定义在_init_.py里面,首先我们在配置文件中选用dataset来进行dataset的初始化操作:

    dataset = __all__[dataset_cfg.DATASET](
        dataset_cfg=dataset_cfg,
        class_names=class_names,
        root_path=root_path,
        training=training,
        logger=logger,
    )

__all__ = {
    'DatasetTemplate': DatasetTemplate,
    'KittiDataset': KittiDataset,
    'NuScenesDataset': NuScenesDataset,
    'WaymoDataset': WaymoDataset
}

all函数主要是通过配置文件选择我们要进行数据预处理的数据集:kittiDataset
KittiDataset是定义在kitti_dataset.py中的一个类,我们首先对他进行一个初始化:

    def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
        """
        Args:
            root_path:
            dataset_cfg:
            class_names:
            training:
            logger:
        """
        # 初始化类,将参数赋值给类的属性
        super().__init__(
            dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
        )
        # 传递参数是 训练集train 还是验证集val
        self.split = self.dataset_cfg.DATA_SPLIT[self.mode]
        # root_path的路径是../data/kitti/
        # kitti数据集一共三个文件夹“training”和“testing”、“ImageSets”
        # 如果是训练集train,将文件的路径指为训练集training ,否则为测试集testing
        self.root_split_path = self.root_path / ('training' if self.split != 'test' else 'testing')
        # /data/kitti/ImageSets/下面一共三个文件:test.txt , train.txt ,val.txt
        # 选择其中的一个文件
        split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
        # 得到.txt文件下的序列号,组成列表sample_id_list
        self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if split_dir.exists() else None
        # 创建用于存放kitti信息的空列表
        self.kitti_infos = []
        # 调用函数,加载kitti数据,mode的值为:train 或者  test
        self.include_kitti_data(self.mode)

主要是根据数据集中的txt文件得到sample_id_list和建立一个kitti_infos的空字典用来后面存储信息。
我们忽略一些辅助函数,直接看到__getitem__函数,主要是读取生成的pkl文件中数据集信息,随后根据info[‘point_cloud’][‘lidar_idx’]确定帧号,进行数据读取和其他info字段的读取初步读取的data_dict,要传入prepare_data(dataset.py父类中定义)进行统一处理,然后即可返回。
前面包含一些相机 lidar坐标系下的图像及点云的标定及转移。后面根据self.prepare_data函数对数据进行增强。
具体做了什么我们可以点进去看一下:
self.prepare_data函数在dataset.py文件中

            data_dict = self.data_augmentor.forward(
                data_dict={
                    **data_dict,
                    'gt_boxes_mask': gt_boxes_mask
                }
            )

将传入的数据进行打包,丢进数据增强函数中。
forward函数根据配置文件中的配置与名称获取增强器,增强的细节操作定义在data_augmentor里面:

DATA_AUGMENTOR:
    DISABLE_AUG_LIST: ['placeholder']
    AUG_CONFIG_LIST:
        - NAME: gt_sampling
          USE_ROAD_PLANE: True
          DB_INFO_PATH:
              - kitti_dbinfos_train.pkl
          PREPARE: {
             filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'],
             filter_by_difficulty: [-1],
          }

          SAMPLE_GROUPS: ['Car:20','Pedestrian:15', 'Cyclist:15']
          NUM_POINT_FEATURES: 4
          DATABASE_WITH_FAKELIDAR: False
          REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
          LIMIT_WHOLE_SCENE: True

        - NAME: random_world_flip
          ALONG_AXIS_LIST: ['x']

        - NAME: random_world_rotation
          WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]

        - NAME: random_world_scaling
          WORLD_SCALE_RANGE: [0.95, 1.05]


随后进行筛选需要检测的gt和需要用到点的哪些属性:

        # 筛选需要检测的gt_boxes
        if data_dict.get('gt_boxes', None) is not None:
            # 返回data_dict[gt_names]中存在于class_name的下标(np.array)
            selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names)
            # 根据selected,选取需要的gt_boxes和gt_names
            data_dict['gt_boxes'] = data_dict['gt_boxes'][selected]
            data_dict['gt_names'] = data_dict['gt_names'][selected]
            # 将当帧数据的gt_names中的类别名称对应到class_names的下标
            # 举个栗子,我们要检测的类别class_names = 'car','person'
            # 对于当前帧,类别gt_names = 'car', 'person', 'car', 'car',当前帧出现了3辆车,一个人,获取索引后,gt_classes = 1, 2, 1, 1
            gt_classes = np.array([self.class_names.index(n) + 1 for n in data_dict['gt_names']], dtype=np.int32)
            # 将类别index信息放到每个gt_boxes的最后
            gt_boxes = np.concatenate((data_dict['gt_boxes'], gt_classes.reshape(-1, 1).astype(np.float32)), axis=1)
            data_dict['gt_boxes'] = gt_boxes

            # 如果box2d不同,根据selected,选取需要的box2d
            if data_dict.get('gt_boxes2d', None) is not None:
                data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][selected]

        # 使用点的哪些属性 比如x,y,z等
        if data_dict.get('points', None) is not None:
            data_dict = self.point_feature_encoder.forward(data_dict)

然后对点云进行预处理 包括移除超出range的点云,打乱点的顺序和将点云转化为voxel
由此 我们可以看到dataset实际上是创建了一个有关数据集的一个类,同时对这个类进行了数据的预处理。
随后我们初始化dataloader:

    dataloader = DataLoader(
        dataset, batch_size=batch_size, pin_memory=True, num_workers=workers,
        shuffle=(sampler is None) and training, collate_fn=dataset.collate_batch,
        drop_last=False, sampler=sampler, timeout=0
    )

初始化DataLoader,此时并没有进行数据采样和加载,只有在训练中才会按照batch size调用__getitem__加载数据
在单卡训练中进行,通过DataLoader进行数据加载
返回dataset、dataloader、sampler

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值