最近在看pvrcnn的源码,加深一下对论文的理解。源码的理解多亏了这位大佬的注释,感谢一下!
首先我们看到train.py函数,在训练之前要对于数据进行一个预处理:
train_set, train_loader, train_sampler = build_dataloader(
dataset_cfg=cfg.DATA_CONFIG,
class_names=cfg.CLASS_NAMES,
batch_size=args.batch_size,
dist=dist_train, workers=args.workers,
logger=logger,
training=True,
merge_all_iters_to_one_epoch=args.merge_all_iters_to_one_epoch,
total_epochs=args.epochs
)
cfg和arg是前面定义的配置文件及变量。
我们进去查看build_dataloader函数做了什么。
build_dataloader定义在_init_.py里面,首先我们在配置文件中选用dataset来进行dataset的初始化操作:
dataset = __all__[dataset_cfg.DATASET](
dataset_cfg=dataset_cfg,
class_names=class_names,
root_path=root_path,
training=training,
logger=logger,
)
__all__ = {
'DatasetTemplate': DatasetTemplate,
'KittiDataset': KittiDataset,
'NuScenesDataset': NuScenesDataset,
'WaymoDataset': WaymoDataset
}
all函数主要是通过配置文件选择我们要进行数据预处理的数据集:kittiDataset
KittiDataset是定义在kitti_dataset.py中的一个类,我们首先对他进行一个初始化:
def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
"""
Args:
root_path:
dataset_cfg:
class_names:
training:
logger:
"""
# 初始化类,将参数赋值给类的属性
super().__init__(
dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
)
# 传递参数是 训练集train 还是验证集val
self.split = self.dataset_cfg.DATA_SPLIT[self.mode]
# root_path的路径是../data/kitti/
# kitti数据集一共三个文件夹“training”和“testing”、“ImageSets”
# 如果是训练集train,将文件的路径指为训练集training ,否则为测试集testing
self.root_split_path = self.root_path / ('training' if self.split != 'test' else 'testing')
# /data/kitti/ImageSets/下面一共三个文件:test.txt , train.txt ,val.txt
# 选择其中的一个文件
split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
# 得到.txt文件下的序列号,组成列表sample_id_list
self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if split_dir.exists() else None
# 创建用于存放kitti信息的空列表
self.kitti_infos = []
# 调用函数,加载kitti数据,mode的值为:train 或者 test
self.include_kitti_data(self.mode)
主要是根据数据集中的txt文件得到sample_id_list和建立一个kitti_infos的空字典用来后面存储信息。
我们忽略一些辅助函数,直接看到__getitem__函数,主要是读取生成的pkl文件中数据集信息,随后根据info[‘point_cloud’][‘lidar_idx’]确定帧号,进行数据读取和其他info字段的读取初步读取的data_dict,要传入prepare_data(dataset.py父类中定义)进行统一处理,然后即可返回。
前面包含一些相机 lidar坐标系下的图像及点云的标定及转移。后面根据self.prepare_data函数对数据进行增强。
具体做了什么我们可以点进去看一下:
self.prepare_data函数在dataset.py文件中
data_dict = self.data_augmentor.forward(
data_dict={
**data_dict,
'gt_boxes_mask': gt_boxes_mask
}
)
将传入的数据进行打包,丢进数据增强函数中。
forward函数根据配置文件中的配置与名称获取增强器,增强的细节操作定义在data_augmentor里面:
DATA_AUGMENTOR:
DISABLE_AUG_LIST: ['placeholder']
AUG_CONFIG_LIST:
- NAME: gt_sampling
USE_ROAD_PLANE: True
DB_INFO_PATH:
- kitti_dbinfos_train.pkl
PREPARE: {
filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'],
filter_by_difficulty: [-1],
}
SAMPLE_GROUPS: ['Car:20','Pedestrian:15', 'Cyclist:15']
NUM_POINT_FEATURES: 4
DATABASE_WITH_FAKELIDAR: False
REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
LIMIT_WHOLE_SCENE: True
- NAME: random_world_flip
ALONG_AXIS_LIST: ['x']
- NAME: random_world_rotation
WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
- NAME: random_world_scaling
WORLD_SCALE_RANGE: [0.95, 1.05]
随后进行筛选需要检测的gt和需要用到点的哪些属性:
# 筛选需要检测的gt_boxes
if data_dict.get('gt_boxes', None) is not None:
# 返回data_dict[gt_names]中存在于class_name的下标(np.array)
selected = common_utils.keep_arrays_by_name(data_dict['gt_names'], self.class_names)
# 根据selected,选取需要的gt_boxes和gt_names
data_dict['gt_boxes'] = data_dict['gt_boxes'][selected]
data_dict['gt_names'] = data_dict['gt_names'][selected]
# 将当帧数据的gt_names中的类别名称对应到class_names的下标
# 举个栗子,我们要检测的类别class_names = 'car','person'
# 对于当前帧,类别gt_names = 'car', 'person', 'car', 'car',当前帧出现了3辆车,一个人,获取索引后,gt_classes = 1, 2, 1, 1
gt_classes = np.array([self.class_names.index(n) + 1 for n in data_dict['gt_names']], dtype=np.int32)
# 将类别index信息放到每个gt_boxes的最后
gt_boxes = np.concatenate((data_dict['gt_boxes'], gt_classes.reshape(-1, 1).astype(np.float32)), axis=1)
data_dict['gt_boxes'] = gt_boxes
# 如果box2d不同,根据selected,选取需要的box2d
if data_dict.get('gt_boxes2d', None) is not None:
data_dict['gt_boxes2d'] = data_dict['gt_boxes2d'][selected]
# 使用点的哪些属性 比如x,y,z等
if data_dict.get('points', None) is not None:
data_dict = self.point_feature_encoder.forward(data_dict)
然后对点云进行预处理 包括移除超出range的点云,打乱点的顺序和将点云转化为voxel
由此 我们可以看到dataset实际上是创建了一个有关数据集的一个类,同时对这个类进行了数据的预处理。
随后我们初始化dataloader:
dataloader = DataLoader(
dataset, batch_size=batch_size, pin_memory=True, num_workers=workers,
shuffle=(sampler is None) and training, collate_fn=dataset.collate_batch,
drop_last=False, sampler=sampler, timeout=0
)
初始化DataLoader,此时并没有进行数据采样和加载,只有在训练中才会按照batch size调用__getitem__加载数据
在单卡训练中进行,通过DataLoader进行数据加载
返回dataset、dataloader、sampler