数据准备
统计gt_box
作者使用generate_gt_database.py生成储存了数据集所有Car的gt box的信息的文件,包括每个gt box的:
- sample_id:gt box所对应的文件名
- cls_type:gt box的cls type
- gt_box3d:gt box的3D信息
- points:gt box中包含的点云
- intensity:gt box中包含的电云的强度
- obj:这个gt box对应object所有的信息,例如center,size,angle,occlusion,level等
dataset
首先定义kitti_dataset,定义通用接口,初始化data的寻找路径等
# lib/datasets/kitti_dataset.py
class KittiDataset(torch_data.Dataset):
def __init__(self, root_dir, split='train'):
self.split = split
is_test = self.split == 'test'
self.imageset_dir = os.path.join(root_dir, 'KITTI', 'object', 'testing' if is_test else 'training')
split_dir = os.path.join(root_dir, 'KITTI', 'ImageSets', split + '.txt')
self.image_idx_list = [x.strip() for x in open(split_dir).readlines()]
self.num_sample = self.image_idx_list.__len__()
self.image_dir = os.path.join(self.imageset_dir, 'image_2')
self.lidar_dir = os.path.join(self.imageset_dir, 'velodyne')
self.calib_dir = os.path.join(self.imageset_dir, 'calib')
self.label_dir = os.path.join(self.imageset_dir, 'label_2')
self.plane_dir = os.path.join(self.imageset_dir, 'planes')
def get_image(self, idx):
def get_image_shape(self, idx):
def get_lidar(self, idx):
def get_calib(self, idx):
def get_label(self, idx):
def get_road_plane(self, idx):
def __len__(self):
def __getitem__(self, item):
然后定义PointRCNN特殊的dataset,主要是完成提取数据,数据增广等操作。这里主要看准备用于训练rpn的数据。其实代码中的注释已经写的非常好了,这里就直接写一下都做了些什么:
- 读取calib,image_shape,pts
# lib/datasets/ki