3D目标检测——代码理解——Second代码:数据处理kitti_dataset.py的理解

3D目标检测—代码理解—Second代码:数据处理kitti_dataset.py的理解

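Second代码的github地址:https://github.com/traveller59/second.pytorch
Second文章的链接:SECOND: Sparsely Embedded Convolutional Detection(Sensors 2018),https://www.mdpi.com/1424-8220/18/10/3337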

笔者刚开始研究3D点云数据的处理,在此将自己对代码的理解分享出来;如有理解错误的地方,还请大家多多批评指正。

  1. dataset.py的部分:
import pathlib
import pickle
import time
from functools import partial

import numpy as np

from second.core import box_np_ops
from second.core import preprocess as prep
from second.data import kitti_common as kitti

# Module-level registry mapping a registry key (by default the class name)
# to the dataset class stored under it.
REGISTERED_DATASET_CLASSES = {}


def register_dataset(cls, name=None):
    """Store a dataset class in the module-level registry and return it.

    The class itself is returned, so the function can be applied as a bare
    class decorator.

    Args:
        cls: the dataset class to register.
        name: registry key to store the class under; when None,
            ``cls.__name__`` is used.

    Returns:
        ``cls``, unchanged.

    Raises:
        AssertionError: if the key is already present in the registry.
    """
    key = cls.__name__ if name is None else name
    assert key not in REGISTERED_DATASET_CLASSES, f"exist class: {REGISTERED_DATASET_CLASSES}"
    REGISTERED_DATASET_CLASSES[key] = cls
    return cls

def get_dataset_class(name):
    """Return the dataset class stored in the registry under ``name``.

    Args:
        name: registry key, as used when the class was registered.

    Returns:
        The class stored under ``name`` in ``REGISTERED_DATASET_CLASSES``.

    Raises:
        AssertionError: if ``name`` is not a key of the registry.
    """
    registry = REGISTERED_DATASET_CLASSES
    assert name in registry, f"available class: {REGISTERED_DATASET_CLASSES}"
    return registry[name]


class Dataset:
    """Abstract base class for a pytorch-like dataset.

    Every concrete dataset is expected to subclass this class and override
    ``__len__`` (the size of the dataset) and ``__getitem__`` (integer
    indexing over ``0 <= index < len(self)``).
    """

    # Placeholder value; subclasses assign their own number of features
    # per point.
    NumPointFeatures = -1

    def __getitem__(self, index):
        """Build the preprocessed input dict for one sample.

        Implementations create the dict that is fed to the network. Its
        expected format is::

            {
                anchors        # anchor boxes
                voxels         # voxel data
                num_points     # number of points
                coordinates    # coordinates
                if training:
                    labels         # labels
                    reg_targets
                [optional] anchors_mask  # slow in SECOND v1.5, don't use it
                [optional] metadata      # for kitti the image index is
                                         # saved in metadata
            }

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def __len__(self):
        """Return the dataset size; subclasses must override."""
        raise NotImplementedError

    def get_sensor_data(self, query):
        """Return raw sensor data in the unified format shared by all datasets.

        Args:
            query: int or dict. An int must be supported for training and
                returns all sensor data. A dict should have this format
                (no example yet)::

                    {
                        sensor_name: {
                            sensor_meta
                        }
                    }

                (TODO: how to deal with unsynchronized data?)

        Returns:
            sensor_data: dict. If ``query`` is an int (return everything),
            a dict with all sensors::

                {
                    sensor_name: sensor_data
                    ...
                    metadata: ... (for kitti, contains image_idx)
                }

            If the sensor is a lidar, all lidar point clouds must be
            concatenated into one array; e.g. a dataset with two lidar
            sensors still returns a single dict::

                {
                    "lidar": {
                        "points": ...
                        ...
                    }
                }
                sensor_data: {
                    points: [N, 3+]
                    [optional] annotations: {
                        "boxes": [N, 7] locs, dims, yaw, in the lidar
                            coordinate system. Must be tested in a provided
                            visualization tool such as
                            second.utils.simplevis or the web tool.
                        "names": array of string.
                    }
                }

            If the sensor is a camera (not used yet)::

                sensor_data: {
                    data: image string (an array is too large)
                    [optional] annotations: {
                        "boxes": [N, 4] 2d bbox
                        "names": array of string.
                    }
                }

            Dataset-specific information goes into ``metadata``; for kitti
            it must contain ``image_idx`` for label file generation::

                metadata: {
                    image_idx: ...
                }
                [optional] calib  # only used for kitti

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def evaluation(self, dt_annos, output_dir):
        """Evaluate model detections; every dataset must provide this.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

  2. kitti_dataset.py的部分:
from pathlib import Path
import pickle
import time
from functools import partial

import numpy as np

from second.core import box_np_ops
from second.core import preprocess as prep
from second.data import kitti_common as kitti
from second.utils.eval import get_coco_eval_result, get_official_eval_result
from second.data.dataset import Dataset, register_dataset
from second.utils.progress_bar import progress_bar_iter as prog_bar

@register_dataset
class KittiDataset(Dataset):
    NumPointFeatures = 4

    def __init__(self,
                 root_path,
                 info_path,
                 class_names=None,
                 prep_func=None,
                 num_point_features=None):
        """Load the pickled KITTI info list and remember the configuration.

        Args:
            root_path: dataset root directory; stored as a ``Path``.
            info_path: path of the pickled info file. Must not be None.
            class_names: class names, stored for later use when converting
                and evaluating detections.
            prep_func: callable stored for ``__getitem__``, which invokes it
                as ``prep_func(input_dict=...)``.
            num_point_features: accepted but not used by this constructor.

        Raises:
            AssertionError: if ``info_path`` is None.
        """
        assert info_path is not None

        # NOTE(review): pickle.load can execute arbitrary code; only point
        # info_path at files you generated or otherwise trust.
        with open(info_path, 'rb') as info_file:
            loaded_infos = pickle.load(info_file)

        self._root_path = Path(root_path)
        self._kitti_infos = loaded_infos
        print("remain number of infos:", len(self._kitti_infos))

        self._class_names = class_names
        self._prep_func = prep_func

    def __len__(self):
        """Return the number of info entries (one per frame) that were loaded."""
        frame_infos = self._kitti_infos
        return len(frame_infos)

    # 该函数作用:将预测得到的信息 转化为 kitti的格式
    def convert_detection_to_kitti_annos(self, detection):
        """Convert per-frame network detections into KITTI-format annotations.

        For every frame the predicted lidar boxes are transformed into the
        camera frame, their corners are projected with ``P2`` to obtain a 2D
        bbox, boxes lying completely outside the image are dropped, and the
        remaining ones are collected under the KITTI result keys.

        Args:
            detection: sequence of per-frame dicts holding ``"box3d_lidar"``,
                ``"label_preds"`` and ``"scores"`` (objects supporting
                ``.detach().cpu().numpy()``, presumably torch tensors) and
                ``"metadata"``. Frame ``i`` is paired with
                ``self._kitti_infos[i]``, so the order has to match the
                info list.

        Returns:
            list of dict, one per frame, with the keys provided by
            ``kitti.get_start_result_anno()`` stacked into arrays (or
            ``kitti.empty_result_anno()`` when no box survives) plus
            ``"metadata"`` copied from the detection.
        """
        class_names = self._class_names
        annos = []
        for i, det in enumerate(detection):
            # NOTE(review): frames are matched with infos by position, not by
            # image_idx; confirm callers always pass detections in info order.
            info = self._kitti_infos[i]

            calib = info["calib"]
            rect = calib["R0_rect"]
            Trv2c = calib["Tr_velo_to_cam"]
            P2 = calib["P2"]

            # .numpy() shares memory with a CPU tensor, so copy before the
            # in-place z shift below; otherwise the caller's detections are
            # modified and a second conversion would shift them twice.
            final_box_preds = det["box3d_lidar"].detach().cpu().numpy().copy()
            label_preds = det["label_preds"].detach().cpu().numpy()
            scores = det["scores"].detach().cpu().numpy()

            anno = kitti.get_start_result_anno()
            num_example = 0

            if final_box_preds.shape[0] != 0:
                # Convert center format to kitti format: presumably moves z
                # from the box center down to the box bottom (column 5 being
                # the box height) - TODO confirm.
                final_box_preds[:, 2] -= final_box_preds[:, 5] / 2
                box3d_camera = box_np_ops.box_lidar_to_camera(
                    final_box_preds, rect, Trv2c)
                locs = box3d_camera[:, :3]
                dims = box3d_camera[:, 3:6]
                angles = box3d_camera[:, 6]

                # Origin of a KITTI camera-frame box.
                camera_box_origin = [0.5, 1.0, 0.5]
                box_corners = box_np_ops.center_to_corner_box3d(
                    locs, dims, angles, camera_box_origin, axis=1)
                # box_corners_in_image: [N, 8, 2]
                box_corners_in_image = box_np_ops.project_to_image(
                    box_corners, P2)
                minxy = np.min(box_corners_in_image, axis=1)
                maxxy = np.max(box_corners_in_image, axis=1)
                # 2D bbox (xmin, ymin, xmax, ymax) enclosing the projected
                # corners of every predicted box.
                bbox = np.concatenate([minxy, maxxy], axis=1)

                # image_shape is indexed as (height, width) below.
                image_shape = info["image"]["image_shape"]

                for j in range(final_box_preds.shape[0]):
                    # Skip boxes that lie entirely outside the image.
                    if bbox[j, 0] > image_shape[1] or bbox[j, 1] > image_shape[0]:
                        continue
                    if bbox[j, 2] < 0 or bbox[j, 3] < 0:
                        continue

                    # Clip the remaining boxes to the image bounds.
                    bbox[j, 2:] = np.minimum(bbox[j, 2:], image_shape[::-1])
                    bbox[j, :2] = np.maximum(bbox[j, :2], [0, 0])

                    anno["bbox"].append(bbox[j])
                    anno["alpha"].append(
                        -np.arctan2(-final_box_preds[j, 1], final_box_preds[j, 0])
                        + box3d_camera[j, 6])
                    anno["dimensions"].append(box3d_camera[j, 3:6])
                    anno["location"].append(box3d_camera[j, :3])
                    anno["rotation_y"].append(box3d_camera[j, 6])

                    anno["name"].append(class_names[int(label_preds[j])])
                    anno["truncated"].append(0.0)
                    anno["occluded"].append(0)
                    anno["score"].append(scores[j])

                    num_example += 1

            # Stack the collected per-object values, or fall back to the
            # empty result when no box survived the filtering.
            if num_example != 0:
                annos.append({n: np.stack(v) for n, v in anno.items()})
            else:
                annos.append(kitti.empty_result_anno())
            annos[-1]["metadata"] = det["metadata"]
        return annos

    # 评价函数evaluation,需要将格式统一
    def evaluation(self, detections, output_dir):
        """Evaluate detections with the official-style and COCO-style metrics.

        When you want to evaluate your own dataset you MUST set the z axis
        and the box z center correctly. To use the KITTI eval functions the
        ground-truth annotations must follow this format::

            {
                bbox: [N, 4], if you fill fake data, MUST HAVE >25 HEIGHT!!!!!!
                alpha: [N], you can use -10 to ignore it.
                occluded: [N], you can use zero.
                truncated: [N], you can use zero.
                name: [N]
                location: [N, 3] center of 3d box.
                dimensions: [N, 3] dim of 3d box.
                rotation_y: [N] angle.
            }

        All fields must be filled, but some fields can be filled with zero.

        Args:
            detections: per-frame detections; converted to KITTI-format
                annotations with ``convert_detection_to_kitti_annos``.
            output_dir: not used by this method.

        Returns:
            None when the first info entry has no ``"annos"``; otherwise a
            dict with a ``"results"`` part and a ``"detail"`` part, each
            holding the ``"official"`` and ``"coco"`` outputs.
        """
        # Only the first info entry is inspected to decide whether ground
        # truth is available.
        if "annos" not in self._kitti_infos[0]:
            return None

        gt_annos = [frame_info["annos"] for frame_info in self._kitti_infos]
        # First convert the standard detections to kitti-format dt annos.
        dt_annos = self.convert_detection_to_kitti_annos(detections)

        # KITTI camera format uses y as the regular "z" axis and the camera
        # box center is [0.5, 1, 0.5]. For regular raw lidar data one would
        # use z_axis = 2, z_center = 0.5.
        axis_kwargs = {"z_axis": 1, "z_center": 1.0}

        official = get_official_eval_result(
            gt_annos, dt_annos, self._class_names, **axis_kwargs)
        coco = get_coco_eval_result(
            gt_annos, dt_annos, self._class_names, **axis_kwargs)

        results = {
            "official": official["result"],
            "coco": coco["result"],
        }
        detail = {
            "eval.kitti": {
                "official": official["detail"],
                "coco": coco["detail"]
            }
        }
        return {"results": results, "detail": detail}

    # 根据帧的序列,获取该帧的所有信息,并做进一步处理,传给example
    def __getitem__(self, idx):
        """Return the preprocessed example for frame ``idx``.

        The raw sensor data of the frame is passed to the stored
        preprocessing callable as ``input_dict``; the frame metadata is then
        attached to the result.

        Args:
            idx: frame query forwarded to ``get_sensor_data``.

        Returns:
            The dict produced by the preprocessing callable, with
            ``"metadata"`` set and ``"anchors_mask"`` (if present) cast to
            uint8.
        """
        sensor_data = self.get_sensor_data(idx)
        example = self._prep_func(input_dict=sensor_data)

        # Start from empty metadata; it is replaced by the frame metadata
        # only when that metadata carries an image index.
        example["metadata"] = {}
        if "image_idx" in sensor_data["metadata"]:
            example["metadata"] = sensor_data["metadata"]

        if "anchors_mask" in example:
            mask = example["anchors_mask"]
            example["anchors_mask"] = mask.astype(np.uint8)
        return example

    # 根据序列号,将数据转化为要求的input格式
    def get_sensor_data(self, query):
        """Collect the raw data of one frame in the unified sensor format.

        Args:
            query: either a frame index, or a dict that must contain
                ``"lidar"`` with an ``"idx"`` entry (the frame index). When
                the dict also has a ``"cam"`` key, the image file is read
                as well.

        Returns:
            dict with ``"lidar"`` (points and, when the info has ``"annos"``,
            box annotations in the lidar frame), ``"metadata"`` (image index
            and shape), ``"calib"`` (rect / Trv2c / P2) and ``"cam"`` (image
            bytes when requested, 2D box annotations when available).

        Raises:
            AssertionError: if ``query`` is a dict without a ``"lidar"`` key.
        """
        if isinstance(query, dict):
            # A "cam" key requests the image in addition to the point cloud.
            read_image = "cam" in query
            assert "lidar" in query
            idx = query["lidar"]["idx"]
        else:
            read_image, idx = False, query

        info = self._kitti_infos[idx]

        res = {
            "lidar": {
                "type": "lidar",
                "points": None,
            },
            "metadata": {
                "image_idx": info["image"]["image_idx"],
                "image_shape": info["image"]["image_shape"],
            },
            "calib": None,
            "cam": {}
        }

        # Resolve the point cloud file; relative paths are taken relative to
        # the dataset root.
        stored_velo_path = info["point_cloud"]['velodyne_path']
        velo_path = Path(stored_velo_path)
        if not velo_path.is_absolute():
            velo_path = Path(self._root_path) / stored_velo_path

        # Prefer the file of the same name in the sibling "<dir>_reduced"
        # directory when it exists.
        reduced_candidate = velo_path.parent.parent / (
            velo_path.parent.stem + '_reduced') / velo_path.name
        if reduced_candidate.exists():
            velo_path = reduced_candidate

        # Flat float32 file reshaped to NumPointFeatures values per point.
        flat_points = np.fromfile(str(velo_path), dtype=np.float32, count=-1)
        res["lidar"]["points"] = flat_points.reshape(
            [-1, self.NumPointFeatures])

        image_path = info["image"]['image_path']
        if read_image:
            image_path = self._root_path / image_path
            with open(str(image_path), 'rb') as image_file:
                image_bytes = image_file.read()
            res["cam"] = {
                "type": "camera",
                "data": image_bytes,
                # File extension without the leading dot.
                "datatype": image_path.suffix[1:],
            }

        calib = info["calib"]
        res["calib"] = {
            'rect': calib['R0_rect'],
            'Trv2c': calib['Tr_velo_to_cam'],
            'P2': calib['P2'],
        }

        if 'annos' not in info:
            return res

        # we need other objects to avoid collision when sample
        annos = kitti.remove_dontcare(info['annos'])
        locs = annos["location"]
        dims = annos["dimensions"]
        rots = annos["rotation_y"]
        gt_names = annos["name"]

        # One row per object: location, dimensions, rotation.
        camera_boxes = np.concatenate(
            [locs, dims, rots[..., np.newaxis]], axis=1).astype(np.float32)
        gt_boxes = box_np_ops.box_camera_to_lidar(
            camera_boxes, calib["R0_rect"], calib["Tr_velo_to_cam"])

        # only center format is allowed. so we need to convert
        # kitti [0.5, 0.5, 0] center to [0.5, 0.5, 0.5].
        # The return value is not used here; presumably the helper edits
        # gt_boxes in place (trailing underscore) - TODO confirm.
        box_np_ops.change_box3d_center_(gt_boxes, [0.5, 0.5, 0],
                                        [0.5, 0.5, 0.5])

        res["lidar"]["annotations"] = {
            'boxes': gt_boxes,
            'names': gt_names,
        }
        res["cam"]["annotations"] = {
            'boxes': annos["bbox"],
            'names': gt_names,
        }
        return res


# convert kitti info v1 to v2 if possible.
def convert_to_kitti_info_version2(info):
    """Upgrade a v1 KITTI info dict to the v2 layout, in place.

    When all of ``"image"``, ``"calib"`` and ``"point_cloud"`` are already
    present the dict is left untouched. Otherwise the three nested entries
    are (re)built from the flat v1 keys. Nothing is returned.

    Args:
        info: info dict of a single frame; modified in place.

    Raises:
        KeyError: if a conversion is needed and a flat v1 key is missing.
    """
    v2_keys = ("image", "calib", "point_cloud")
    if all(key in info for key in v2_keys):
        return

    info["image"] = dict(
        image_shape=info["img_shape"],
        image_idx=info['image_idx'],
        image_path=info['img_path'],
    )
    info["calib"] = dict(
        R0_rect=info['calib/R0_rect'],
        Tr_velo_to_cam=info['calib/Tr_velo_to_cam'],
        P2=info['calib/P2'],
    )
    info["point_cloud"] = dict(velodyne_path=info['velodyne_path'])


# Write KITTI-format annotations to label files, one file per frame.
def kitti_anno_to_label_file(annos, folder):
    """Write one KITTI label text file per annotated frame.

    Args:
        annos: iterable of per-frame annotation dicts. Each needs
            ``"metadata"`` with ``"image_idx"`` and the per-object entries
            name, alpha, bbox, location, dimensions, rotation_y and score;
            the object count is taken from ``anno["bbox"].shape[0]``.
        folder: directory the ``<image index>.txt`` files are written to.
    """
    out_dir = Path(folder)
    # Per-object fields handed to kitti.kitti_result_line, in this order.
    field_names = ('name', 'alpha', 'bbox', 'location', 'dimensions',
                   'rotation_y', 'score')

    for frame_anno in annos:
        frame_idx = frame_anno["metadata"]["image_idx"]
        object_count = frame_anno["bbox"].shape[0]

        # One formatted text line per object in this frame.
        rendered = [
            kitti.kitti_result_line(
                {field: frame_anno[field][k] for field in field_names})
            for k in range(object_count)
        ]

        target = out_dir / f"{kitti.get_image_index_str(frame_idx)}.txt"
        content = '\n'.join(rendered)
        with open(target, 'w') as label_fh:
            label_fh.write(content)


# Read an image-set split file; the result is a list with one integer
# sample index per non-blank line of the file.
def _read_imageset_file(path):
    """Return the integer sample indices listed in an image-set file.

    Each non-blank line is parsed with ``int`` (surrounding whitespace and
    leading zeros are accepted). Blank or whitespace-only lines, such as a
    trailing empty line, are skipped instead of raising ``ValueError``.

    Args:
        path: path of the text file to read.

    Returns:
        list of int, in file order.

    Raises:
        ValueError: if a non-blank line is not a valid integer.
    """
    with open(path, 'r') as f:
        return [int(line) for line in f if line.strip()]


def _calculate_num_points_in_gt(data_path,
                                infos,
                                relative_path,
                                remove_outside=True,
                                num_features=4):
    """Store, for every frame, how many lidar points fall inside each gt box.

    The counts are written to ``info["annos"]["num_points_in_gt"]`` (int32)
    of every info entry. Objects named ``'DontCare'`` receive ``-1``; the
    slicing below takes the first ``num_obj`` entries as the real objects,
    i.e. it presumably relies on DontCare entries being stored last.

    Args:
        data_path: dataset root, used when ``relative_path`` is truthy.
        infos: list of per-frame info dicts; their ``"annos"`` are modified
            in place.
        relative_path: when truthy, the stored velodyne path is joined onto
            ``data_path``; otherwise it is used as stored.
        remove_outside: when truthy, the points are first filtered with
            ``box_np_ops.remove_outside_points`` using the calibration and
            the image shape.
        num_features: number of float32 values stored per point.
    """
    for frame in infos:
        cloud_meta = frame["point_cloud"]
        image_meta = frame["image"]
        calib = frame["calib"]

        velo_file = cloud_meta["velodyne_path"]
        if relative_path:
            velo_file = str(Path(data_path) / velo_file)

        # Flat float32 file reshaped to num_features values per point.
        cloud = np.fromfile(velo_file, dtype=np.float32, count=-1)
        cloud = cloud.reshape([-1, num_features])

        rect = calib['R0_rect']
        Trv2c = calib['Tr_velo_to_cam']
        P2 = calib['P2']
        if remove_outside:
            cloud = box_np_ops.remove_outside_points(
                cloud, rect, Trv2c, P2, image_meta["image_shape"])

        annos = frame['annos']
        # Number of objects that are not 'DontCare'.
        num_obj = sum(1 for label in annos['name'] if label != 'DontCare')
        boxes_camera = np.concatenate(
            [
                annos['location'][:num_obj],
                annos['dimensions'][:num_obj],
                annos['rotation_y'][:num_obj][..., np.newaxis],
            ],
            axis=1)
        boxes_lidar = box_np_ops.box_camera_to_lidar(boxes_camera, rect, Trv2c)

        # Summing the point/box membership result over axis 0 gives one
        # count per box.
        inside = box_np_ops.points_in_rbbox(cloud[:, :3], boxes_lidar)
        counts = inside.sum(0)

        # Pad with -1 for the ignored (DontCare) objects.
        num_ignored = len(annos['dimensions']) - num_obj
        counts = np.concatenate([counts, -np.ones([num_ignored])])
        annos["num_points_in_gt"] = counts.astype(np.int32)


# Generate the pickled info files that describe the KITTI dataset splits.
def create_kitti_info_file(data_path, save_path=None, relative_path=True):
    """Create the train / val / trainval / test KITTI info pickle files.

    The sample ids of each split are read from the ``ImageSets`` directory
    located next to this source file. For train and val the number of
    points inside every ground-truth box is computed as well, before the
    infos are pickled.

    Args:
        data_path: root directory of the KITTI dataset.
        save_path: directory the ``kitti_infos_*.pkl`` files are written
            to; defaults to ``data_path``.
        relative_path: forwarded to ``kitti.get_kitti_image_info`` and to
            ``_calculate_num_points_in_gt``.
    """
    def _dump(obj, target):
        # Pickle obj into the file at target.
        with open(target, 'wb') as out:
            pickle.dump(obj, out)

    split_dir = Path(__file__).resolve().parent / "ImageSets"
    train_img_ids = _read_imageset_file(str(split_dir / "train.txt"))
    val_img_ids = _read_imageset_file(str(split_dir / "val.txt"))
    test_img_ids = _read_imageset_file(str(split_dir / "test.txt"))

    print("Generate info. this may take several minutes.")
    save_path = Path(data_path if save_path is None else save_path)

    # Options shared by all three info-gathering calls.
    shared = dict(velodyne=True, calib=True, relative_path=relative_path)

    # Training split.
    kitti_infos_train = kitti.get_kitti_image_info(
        data_path, training=True, image_ids=train_img_ids, **shared)
    _calculate_num_points_in_gt(data_path, kitti_infos_train, relative_path)
    filename = save_path / 'kitti_infos_train.pkl'
    print(f"Kitti info train file is saved to {filename}")
    _dump(kitti_infos_train, filename)

    # Validation split.
    kitti_infos_val = kitti.get_kitti_image_info(
        data_path, training=True, image_ids=val_img_ids, **shared)
    _calculate_num_points_in_gt(data_path, kitti_infos_val, relative_path)
    filename = save_path / 'kitti_infos_val.pkl'
    print(f"Kitti info val file is saved to {filename}")
    _dump(kitti_infos_val, filename)

    # Train and val infos together in a single file.
    filename = save_path / 'kitti_infos_trainval.pkl'
    print(f"Kitti info trainval file is saved to {filename}")
    _dump(kitti_infos_train + kitti_infos_val, filename)

    # Test split: requested without label info, and no point counts are
    # computed for it.
    kitti_infos_test = kitti.get_kitti_image_info(
        data_path,
        training=False,
        label_info=False,
        image_ids=test_img_ids,
        **shared)
    filename = save_path / 'kitti_infos_test.pkl'
    print(f"Kitti info test file is saved to {filename}")
    _dump(kitti_infos_test, filename)


def _create_reduced_point_cloud(data_path,
                                info_path,
                                save_path=None,
                                back=False):
    """Write a filtered copy of every point cloud listed in an info file.

    Each cloud is read as float32 values with 4 per point, filtered with
    ``box_np_ops.remove_outside_points`` (called with the calibration and
    the image shape) and written back as raw binary.

    Args:
        data_path: dataset root; the stored velodyne paths are joined onto
            it.
        info_path: path of the pickled info list.
        save_path: output directory. When None, the file goes into the
            sibling ``<velodyne dir>_reduced`` directory under its original
            file name.
        back: when True, the x coordinate is negated before filtering and
            ``"_back"`` is appended to the output file name.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use info
    # files you generated or otherwise trust.
    with open(info_path, 'rb') as f:
        kitti_infos = pickle.load(f)

    for info in prog_bar(kitti_infos):
        pc_info = info["point_cloud"]
        image_info = info["image"]
        calib = info["calib"]

        v_path = Path(data_path) / pc_info['velodyne_path']
        points_v = np.fromfile(
            str(v_path), dtype=np.float32, count=-1).reshape([-1, 4])
        rect = calib['R0_rect']
        P2 = calib['P2']
        Trv2c = calib['Tr_velo_to_cam']

        if back:
            # Mirror x so that the filter below keeps the points that were
            # originally at negative x.
            points_v[:, 0] = -points_v[:, 0]
        points_v = box_np_ops.remove_outside_points(points_v, rect, Trv2c, P2,
                                                    image_info["image_shape"])

        if save_path is None:
            save_file = v_path.parent.parent / (
                v_path.parent.stem + "_reduced") / v_path.name
        else:
            save_file = Path(save_path) / v_path.name
        if back:
            # Path does not support `+= str` (the previous code raised
            # TypeError here when save_path was None); rebuild the file
            # name instead.
            save_file = save_file.with_name(save_file.name + "_back")

        # Make sure the output directory exists before writing.
        save_file.parent.mkdir(parents=True, exist_ok=True)
        # The dump is raw binary data, so open the file in binary mode.
        with open(save_file, 'wb') as f:
            points_v.tofile(f)


# Create the reduced point clouds for the train, val and test splits by
# calling _create_reduced_point_cloud once per split.
def create_reduced_point_cloud(data_path,
                               train_info_path=None,
                               val_info_path=None,
                               test_info_path=None,
                               save_path=None,
                               with_back=False):
    """Create reduced point clouds for the train, val and test splits.

    Args:
        data_path: dataset root directory.
        train_info_path: train info pickle; defaults to
            ``<data_path>/kitti_infos_train.pkl``.
        val_info_path: val info pickle; defaults to
            ``<data_path>/kitti_infos_val.pkl``.
        test_info_path: test info pickle; defaults to
            ``<data_path>/kitti_infos_test.pkl``.
        save_path: output directory forwarded to the per-split helper.
        with_back: when truthy, every split is processed a second time
            with ``back=True``.
    """
    train_info_path = (Path(data_path) / 'kitti_infos_train.pkl'
                       if train_info_path is None else train_info_path)
    val_info_path = (Path(data_path) / 'kitti_infos_val.pkl'
                     if val_info_path is None else val_info_path)
    test_info_path = (Path(data_path) / 'kitti_infos_test.pkl'
                      if test_info_path is None else test_info_path)

    split_infos = (train_info_path, val_info_path, test_info_path)

    # Regular pass over all three splits.
    for split_info in split_infos:
        _create_reduced_point_cloud(data_path, split_info, save_path)

    # Optional second pass with back=True.
    if with_back:
        for split_info in split_infos:
            _create_reduced_point_cloud(
                data_path, split_info, save_path, back=True)


if __name__ == "__main__":
    # `fire` is used here but is not among the imports at the top of this
    # file, so running the script raised NameError. Import it where it is
    # needed; fire.Fire() then exposes the module's functions on the
    # command line.
    import fire

    fire.Fire()
  • 3
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值