【OpenPCDet】DatasetTemplate源码阅读

一、DatasetTemplate简介

DatasetTemplate为openpcdet的数据集模板,openpcdet中的KittiDataset和NuScenesDataset类都继承自该类,在编写新的dataset类时,也应该继承DatasetTemplate类。

二、代码阅读

from collections import defaultdict
from pathlib import Path

import numpy as np
import torch.utils.data as torch_data

from ..utils import common_utils
from .augmentor.data_augmentor import DataAugmentor
from .processor.data_processor import DataProcessor
from .processor.point_feature_encoder import PointFeatureEncoder

""" DatasetTemplate类继承了torch的Dataset类 """
class DatasetTemplate(torch_data.Dataset):
    def __init__(self, dataset_cfg=None, class_names=None, training=True, root_path=None, logger=None):
        super().__init__()
        self.dataset_cfg = dataset_cfg # 传入dataset的config字典
        self.training = training # bool值
        self.class_names = class_names # 分类类别,在kitti中一般为'car',也可以包含其他类
        self.logger = logger # 日志记录
        # 读取字典中DATA_PATH的值作为数据集的根目录,返回Path()对象
        self.root_path = root_path if root_path is not None else Path(self.dataset_cfg.DATA_PATH)
        self.logger = logger
        if self.dataset_cfg is None or class_names is None:
            return

        self.point_cloud_range = np.array(self.dataset_cfg.POINT_CLOUD_RANGE, dtype=np.float32)
        """ 
        kitti_dataset.yaml 中 POINT_FEATURE_ENCODING 如下
			POINT_FEATURE_ENCODING: {
    			encoding_type: absolute_coordinates_encoding,
   				used_feature_list: ['x', 'y', 'z', 'intensity'],
    			src_feature_list: ['x', 'y', 'z', 'intensity'],
			}
		"""
		# 初始化PointFeatureEncoder对象
        self.point_feature_encoder = PointFeatureEncoder(
            self.dataset_cfg.POINT_FEATURE_ENCODING,
            point_cloud_range=self.point_cloud_range
        )
        # 训练模式下,定义数据增强类
        self.data_augmentor = DataAugmentor(
            self.root_path, self.dataset_cfg.DATA_AUGMENTOR, self.class_names, logger=self.logger
        ) if self.training else None
        # 定义数据处理类
        self.data_processor = DataProcessor(
            self.dataset_cfg.DATA_PROCESSOR, point_cloud_range=self.point_cloud_range, training=self.training
        )

        self.grid_size = self.data_processor.grid_size
        self.voxel_size = self.data_processor.voxel_size
        self.total_epochs = 0
        self._merge_all_iters_to_one_epoch = False
	# @property 可以让对象像访问属性一样区访问方法 self.mode 
    @property
    def mode(self):
        return 'train' if self.training else 'test'

    def __getstate__(self):
        d = dict(self.__dict__)
        del d['logger']
        return d
	# 更新成员变量的值
    def __setstate__(self, d):
        self.__dict__.update(d)

	# 自定义数据集时需要实现该方法
    @staticmethod
    def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
        """
        To support a custom dataset, implement this function to receive the predicted results from the model, and then
        transform the unified normative coordinate to your required coordinate, and optionally save them to disk.

        Args:
            batch_dict: dict of original data from the dataloader
            pred_dicts: dict of predicted results from the model
                pred_boxes: (N, 7), Tensor
                pred_scores: (N), Tensor
                pred_labels: (N), Tensor
            class_names:
            output_path: if it is not None, save the results to this path
        Returns:

        """

    def merge_all_iters_to_one_epoch(self, merge=True, epochs=None):
        if merge:
            self._merge_all_iters_to_one_epoch = True
            self.total_epochs = epochs
        else:
            self._merge_all_iters_to_one_epoch = False

    def __len__(self):
        raise NotImplementedError
	# 自定义数据集时实现该方法,加载原始数据和labels,并将这些数据转换到统一的坐标下,调用self.prepare_data()来处理数据和送进模型
    def __getitem__(self, index):
        """
        To support a custom dataset, implement this function to load the raw data (and labels), then transform them to
        the unified normative coordinate and call the function self.prepare_data() to process the data and send them
        to the model.

        Args:
            index:

        Returns:

        """
        raise NotImplementedError
	# 
    def prepare_data(self, data_dict):
        """
        Args:
            data_dict:
                points: (N, 3 + C_in)
                gt_boxes: optional, (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
                gt_names: optional, (N), string
                ...

        Returns:
            data_dict:
                frame_id: string
                points: (N, 3 + C_in)
                gt_boxes: optional, (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]
                gt_names: optional, (N), string
                use_lead_xyz: bool
                voxels: optional (num_voxels, max_points_per_voxel, 3 + C)
                voxel_coords: optional (num_voxels, 3)
                voxel_num_points: optional (num_voxels)
                ...
        """
        # 训练模式下,对存在于class_name中的数据进行增强
        if self.training:
            assert 'gt_boxes' in data_dict, 'gt_boxes should be provided for training'
            # 返回一个bool数组,记录自定义数据集中ground_truth_name列表在不在我们需要检测的类别列表self.class_name里面
            # 比如kitti数据集中data_dict['gt_names']=['car','person','cyclist'],self.class_name='car',则gt_boxes_mask=[True, False, False]
            gt_boxes_mask = np.array([n in self.class_names for n in data_dict['gt_names']], dtype=np.bool_)
			# 数据增强 传入字典参数,**data_dict是将data_dict里面的key-value对都拿出来
            data_dict = self.data_augmentor.forward(
                data_dict={
                    **data_dict,
                    'gt_boxes_mask': gt_boxes_mask
                }
            )
            
            if len(data_dict['gt_boxes']) == 0:
                new_index = np.random.randint(self.__len__())
                return self.__getitem__(new_index)
		# 筛选需要检测的gt_boxes
        if data_dict.get('gt_boxes', None) is not None:
        	# 返回data_dict['gt_names']中存在于class_name的下标, 也就是我们一开始指定要检测哪些类,只需要获得这些类的下标
            selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names)
            # 根据selected,留下我们需要的gt_boxes和gt_names
            data_dict['gt_boxes'] = data_dict['gt_boxes'][selected]
            data_dict['gt_names'] = data_dict['gt_names'][selected]
            # 将当帧数据的gt_names中的类别名称对应到class_names的下标
            # 举个栗子,我们要检测的类别class_names = ['car','person'],对于当前帧,类别gt_names = ['car', 'person', 'car', 'car'],当前帧出现了3辆车,一辆单车,获取索引后,gt_classes = [1, 2, 1, 1]
            gt_classes = np.array([self.class_names.index(n) + 1 for n in data_dict['gt_names']], dtype=np.int32)
            # 将类别index信息放到每个gt_boxes的最后
            gt_boxes = np.concatenate((data_dict['gt_boxes'], gt_classes.reshape(-1, 1).astype(np.float32)), axis=1)
            data_dict['gt_boxes'] = gt_boxes
		# 使用点的哪些属性 比如x,y,z等
        data_dict = self.point_feature_encoder.forward(data_dict)
		# 对点云进行预处理,包括移除超出point_cloud_range的点、 打乱点的顺序以及将点云转换为voxel
        data_dict = self.data_processor.forward(
            data_dict=data_dict
        )
        data_dict.pop('gt_names', None)

        return data_dict

    @staticmethod
    def collate_batch(batch_list, _unused=False):
    	# defaultdict创建一个带有默认返回值的字典,当key不存在时,返回默认值,list默认返回一个空[]
        data_dict = defaultdict(list)
        # 把batch里面的每个sample按照key-value合并
        for cur_sample in batch_list:
            for key, val in cur_sample.items():
                data_dict[key].append(val)
        batch_size = len(batch_list)
        ret = {}

        for key, val in data_dict.items():
            try:
                if key in ['voxels', 'voxel_num_points']:
                    ret[key] = np.concatenate(val, axis=0)
                elif key in ['points', 'voxel_coords']:
                    coors = []
                    for i, coor in enumerate(val):
                        coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i)
                        coors.append(coor_pad)
                    ret[key] = np.concatenate(coors, axis=0)
                elif key in ['gt_boxes']:
                    max_gt = max([len(x) for x in val])
                    batch_gt_boxes3d = np.zeros((batch_size, max_gt, val[0].shape[-1]), dtype=np.float32)
                    for k in range(batch_size):
                        batch_gt_boxes3d[k, :val[k].__len__(), :] = val[k]
                    ret[key] = batch_gt_boxes3d
                else:
                    ret[key] = np.stack(val, axis=0)
            except:
                print('Error in collate_batch: key=%s' % key)
                raise TypeError

        ret['batch_size'] = batch_size
        return ret

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值