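I. Introduction
Evaluating detection at different distances shows how well a detector handles mid-range and long-range targets. VirConv, for example, specifically evaluates how virtual points affect performance on distant, sparse objects.
However, the source code does not provide an explicit way to run this test, so this post gives a brief reproduction.
Framework: OpenPCDet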
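II. Reproduction code and walkthrough
1. Reference source code
That GitHub repository already provides clear evaluation code, so we only need to adapt it to the prediction file result.pkl and the ground-truth pkl.
2. Modifying the source code
(1) Reproducing the results from the GitHub repository
I did not record the environment setup, so it is skipped here.
First, the README.md of their GitHub repository already gives an example. Following it, simply copy the code below into main in evaluate.py.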
import kitti_common as kitti
from eval import get_official_eval_result, get_coco_eval_result

def _read_imageset_file(path):
    # Each line of the split file holds one frame index, e.g. "000123"
    with open(path, 'r') as f:
        lines = f.readlines()
    return [int(line) for line in lines]

det_path = "/path/to/your_result_folder"  # path to the prediction results (pkl)
dt_annos = kitti.get_label_annos(det_path)
gt_path = "/path/to/your_gt_label_folder"  # path to the ground-truth labels (pkl)
gt_split_file = "/path/to/val.txt"  # from https://xiaozhichen.github.io/files/mv3d/imagesets.tar.gz
val_image_ids = _read_imageset_file(gt_split_file)
gt_annos = kitti.get_label_annos(gt_path, val_image_ids)
print(get_official_eval_result(gt_annos, dt_annos, 0))  # official KITTI evaluation, 6s in my computer
print(get_coco_eval_result(gt_annos, dt_annos, 0))  # COCO-style evaluation, 18s in my computer
(2) Generating result.pkl
After training and inference in OpenPCDet, a result.pkl is normally written to the output directory. Mine, for example, is under OpenPCDet/output/models/kitti/VirConv-S-train/default/eval/eval_with_train/epoch_5/val.
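Before wiring result.pkl into the evaluator, it can help to check what the file contains. The sketch below is only an illustration: it writes a tiny stand-in file and then loads and summarizes it the way one might inspect a real result.pkl. The helper name `summarize_pkl` and the field names `frame_id`, `name`, and `score` are assumptions made for this example, not taken from the post.

```python
import os
import pickle
import tempfile


def summarize_pkl(path):
    """Load a pickled list of per-frame dicts; return (frame count, sorted keys of the first frame)."""
    with open(path, 'rb') as f:
        infos = pickle.load(f)
    first_keys = sorted(infos[0].keys()) if infos else []
    return len(infos), first_keys


# Stand-in for result.pkl; these keys are assumed, not read from a real file.
toy_preds = [
    {'frame_id': '000001', 'name': ['Car', 'Pedestrian'], 'score': [0.9, 0.4]},
    {'frame_id': '000002', 'name': ['Car'], 'score': [0.7]},
]

with tempfile.TemporaryDirectory() as tmp_dir:
    toy_path = os.path.join(tmp_dir, 'result.pkl')
    with open(toy_path, 'wb') as f:
        pickle.dump(toy_preds, f)
    num_frames, first_keys = summarize_pkl(toy_path)

print(num_frames, first_keys)
```

Pointing `summarize_pkl` at the real file instead of the temporary one shows how many frames were predicted and which fields each frame carries.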
(3) Reading the predictions result.pkl and the ground-truth pkl
For convenience, I wrote my own pkl loading and preprocessing code. The two files read are result.pkl and kitti_infos_val.pkl.
import argparse
import pickle

import numpy as np
from pcdet.utils.common_utils import drop_info_with_name  # assumed location of this helper in OpenPCDet

############################### argument setup ################################
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument('--save_path', type=str, default="/path/to/your_result_folder",
                    help='folder where the results are saved')
parser.add_argument('--gt_split_file', type=str, default="/path/to/val.txt",
                    help='path to the validation split file')
parser.add_argument('--pred_infos', type=str, default=None, help='pickle file')
parser.add_argument('--gt_infos', type=str, default=None, help='pickle file')
parser.add_argument('--sampled_interval', type=int, default=1, help='sampled interval for GT sequences')
args = parser.parse_args()
############################## pkl loading & preprocessing ########################
with open(args.pred_infos, 'rb') as f:
    pred_infos = pickle.load(f)  # load the prediction infos with pickle
with open(args.gt_infos, 'rb') as f:
    gt_infos = pickle.load(f)  # load the gt boxes with pickle
gt_infos_dst = []
for idx in range(0, len(gt_infos), args.sampled_interval):  # gt preprocessing
    cur_info = gt_infos[idx]['annos']
    # cur_info['frame_id'] = gt_infos[idx]['annos']
    cur_info = drop_info_with_name(cur_info, name='DontCare')  # discard DontCare
    gt_names = cur_info['name']
    cur_info['name'] = np.array(['Car' if n == 'Van' else n for n in gt_names])  # relabel Van as Car
    cur_info['frame_id'] = gt_infos[idx]['point_cloud']['lidar_idx']
    gt_infos_dst.append(cur_info)
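The preprocessing above does three things per frame: drop DontCare objects, relabel Van as Car, and attach a frame id. The sketch below replays those steps on a toy annotation so the effect is easy to see, and adds a sanity check that ground truth and predictions line up frame by frame. It is an illustration under assumptions: plain Python lists stand in for the NumPy arrays, and `drop_by_name`, `merge_van_into_car`, and `frame_ids_match` are hypothetical helpers written for this example, not OpenPCDet functions.

```python
def drop_by_name(anno, name):
    """Return a copy of anno without the objects whose class equals name."""
    keep = [i for i, n in enumerate(anno['name']) if n != name]
    return {key: [values[i] for i in keep] for key, values in anno.items()}


def merge_van_into_car(anno):
    """Return a copy of anno with every 'Van' label renamed to 'Car'."""
    out = dict(anno)
    out['name'] = ['Car' if n == 'Van' else n for n in anno['name']]
    return out


def frame_ids_match(gt_annos, dt_annos):
    """True when both lists have the same length and the same frame order."""
    if len(gt_annos) != len(dt_annos):
        return False
    return all(g['frame_id'] == d['frame_id'] for g, d in zip(gt_annos, dt_annos))


# Toy ground truth for one frame: every per-object field has one entry per object.
toy_gt = {
    'name': ['Car', 'Van', 'DontCare'],
    'location': [[1.0, 1.5, 10.0], [2.0, 1.5, 35.0], [0.0, 0.0, 0.0]],
}

# Filter first, then attach frame_id, because frame_id is not a per-object field.
cleaned = merge_van_into_car(drop_by_name(toy_gt, 'DontCare'))
cleaned['frame_id'] = '000001'

toy_dt = [{'frame_id': '000001', 'name': ['Car']}]
aligned = frame_ids_match([cleaned], toy_dt)

print(cleaned['name'], len(cleaned['location']), aligned)
```

Running a check like `frame_ids_match` on the real lists before evaluation is one way to catch a prediction file that was produced on a different split than the ground-truth pkl.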
(4) Rewriting the evaluation function
The evaluation function lives in eval.py, which contains get_official_eval_result_by_distance:
def get_official_eval_result_by_distance(gt_annos, dt_annos, current_classes, PR_detail_dict=None):
    print("Evaluating kitti by distance")
    overlap_0_7 = np.array([[0.7, 0.5, 0.5, 0.7, 0.5, 0.7],
                            [0.7, 0.5, 0.5, 0.7, 0.5, 0.7],
                            [0.7, 0.5, 0.5, 0.7, 0.5, 0.7]])
    overlap_0_5 = np.array([[0.7, 0.5, 0.5, 0.7, 0.5, 0.5],
                            [0.5, 0.25, 0.25, 0.5, 0.25, 0.5],
                            [0.5, 0.25, 0.25, 0.5, 0.25, 0.5]])
    min_overlaps = np.stack([overlap_0_7, overlap_0_5], axis=0)  # [2, 3, 6]
    class_to_name = {
        0: 'Car',
        1: 'Pedestrian',
        2: 'Cyclist',
        3: 'Van',
        4: 'Person_sitting',
        5: 'Truck'
    }
    name_to_class = {v: n for n, v in class_to_name.items()}
    if not isinstance(current_classes, (list, tuple)):
        current_classes = [current_classes]
    current_classes_int = []
    for curcls in current_classes:
        if isinstance(curcls, str):
            current_classes_int.append(name_to_class[curcls])
        else:
            current_classes_int.append(curcls)
    current_classes = current_classes_int
    min_overlaps = min_overlaps[:, :, current_classes]
    result = ''
    # check whether alpha is valid
    compute_aos = False
    for anno in dt_annos:
        if anno['alpha'].shape[0] != 0:
            if anno['alpha'][0] != -10:
                compute_aos = True
            break
    mAPbbox, mAPbev, mAP3d, mAPaos, mAPbbox_R40, mAPbev_R40, mAP3d_R40, mAPaos_R40 = do_eval(
        gt_annos, dt_annos, current_classes, min_overlaps, compute_aos, PR_detail_dict=PR_detail_dict, DIForDIS=False)
    ret_dict = {}
    for j, curcls in enumerate(current_classes):
        # mAP threshold array: [num_minoverlap, metric, class]
        # mAP result: [num_class, num_diff, num_minoverlap]
        # (the second axis is indexed 0..8 below, one entry per distance range)
        cls_name = class_to_name[curcls]
        for i in range(min_overlaps.shape[0]):
            """if compute_aos:
                result += print_str((f"aos AP:{mAPaos_R40[j, 0, i]:.2f}, "
                                     f"{mAPaos_R40[j, 1, i]:.2f}, "
                                     f"{mAPaos_R40[j, 2, i]:.2f}"))
                if i == 0:
                    ret_dict['%s_aos_30m_R40' % class_to_name[curcls]] = mAPaos_R40[j, 0, 0]
                    ret_dict['%s_aos_50m_R40' % class_to_name[curcls]] = mAPaos_R40[j, 1, 0]
                    ret_dict['%s_aos_70m_R40' % class_to_name[curcls]] = mAPaos_R40[j, 2, 0]
            """
            if i == 0:
                # k indexes the 10 m distance ranges; index 8 is the open ">80m" range
                for k in range(8):
                    ret_dict['%s_3d_%d-%dm' % (cls_name, k*10, (k+1)*10)] = mAP3d[j, k, 0]
                ret_dict['%s_3d_>80m' % cls_name] = mAP3d[j, 8, 0]
                for k in range(8):
                    ret_dict['%s_bev_%d-%dm' % (cls_name, k*10, (k+1)*10)] = mAPbev[j, k, 0]
                ret_dict['%s_bev_>80m' % cls_name] = mAPbev[j, 8, 0]
                for k in range(8):
                    ret_dict['%s_image_%d-%dm' % (cls_name, k*10, (k+1)*10)] = mAPbbox[j, k, 0]
                ret_dict['%s_image_>80m' % cls_name] = mAPbbox[j, 8, 0]
                # R40 variants use their own '_R40' keys so they do not overwrite the entries above
                for k in range(8):
                    ret_dict['%s_3d_%d-%dm_R40' % (cls_name, k*10, (k+1)*10)] = mAP3d_R40[j, k, 0]
                ret_dict['%s_3d_>80m_R40' % cls_name] = mAP3d_R40[j, 8, 0]
                for k in range(8):
                    ret_dict['%s_bev_%d-%dm_R40' % (cls_name, k*10, (k+1)*10)] = mAPbev_R40[j, k, 0]
                ret_dict['%s_bev_>80m_R40' % cls_name] = mAPbev_R40[j, 8, 0]
                for k in range(8):
                    ret_dict['%s_image_%d-%dm_R40' % (cls_name, k*10, (k+1)*10)] = mAPbbox_R40[j, k, 0]
                ret_dict['%s_image_>80m_R40' % cls_name] = mAPbbox_R40[j, 8, 0]
    return result, ret_dict
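The function above only names the distance ranges; how do_eval assigns each box to a range is not shown in this post. As a rough sketch of one way to do it, the code below assumes the range is the ground-plane distance from the sensor to the box center in KITTI camera coordinates (x to the right, z forward), with 10 m bins and a final open bin, which matches the key names used above. `distance_bin` and `bin_label` are hypothetical helpers, not part of eval.py.

```python
import math


def distance_bin(x, z, bin_size=10.0, num_bins=9):
    """Map a box center to a bin index; the last bin collects everything farther out."""
    dist = math.hypot(x, z)  # ground-plane distance, ignoring height
    return min(int(dist // bin_size), num_bins - 1)


def bin_label(idx, bin_size=10, num_bins=9):
    """Build the range label used in the result keys, e.g. '20-30m' or '>80m'."""
    if idx == num_bins - 1:
        return '>%dm' % (idx * bin_size)
    return '%d-%dm' % (idx * bin_size, (idx + 1) * bin_size)


# Toy box centers (x, z) in metres; values are arbitrary.
centers = [(0.0, 5.0), (12.0, 20.0), (30.0, 42.0), (0.0, 85.0), (60.0, 80.0)]
labels = [bin_label(distance_bin(x, z)) for x, z in centers]
print(labels)
```

Changing `bin_size` or `num_bins` here shows how the key names would shift if different distance thresholds were chosen, provided the evaluator's own binning is changed to match.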
The classes and distance thresholds can be changed here. We import the function into evaluate.py and call it:
# import
from eval import get_official_eval_result_by_distance, get_official_eval_result, get_coco_eval_result
# evaluate (gt_annos and dt_annos are presumably gt_infos_dst and pred_infos from step (3))
_, results = get_official_eval_result_by_distance(gt_annos, dt_annos, 0)
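The post passes a --save_path argument but does not show how the returned dictionary is written there. One possible way, sketched under that assumption, is to dump it as JSON. `save_results` and the file name `distance_eval.json` are hypothetical, and the numbers are placeholders rather than real evaluation results.

```python
import json
import os
import tempfile


def save_results(results, save_dir, filename='distance_eval.json'):
    """Write a {metric name: value} dict to save_dir as JSON and return the file path."""
    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, filename)
    # Cast to float so NumPy scalars, as an evaluator may return, serialize cleanly.
    serializable = {key: float(value) for key, value in results.items()}
    with open(out_path, 'w') as f:
        json.dump(serializable, f, indent=2, sort_keys=True)
    return out_path


# Placeholder values only; a real run would pass the dict returned by the evaluator.
toy_results = {'Car_3d_0-10m_R40': 1.0, 'Car_3d_10-20m_R40': 2.0}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = save_results(toy_results, os.path.join(tmp_dir, 'distance'))
    with open(path) as f:
        reloaded = json.load(f)

print(reloaded)
```

In evaluate.py the call would then look like `save_results(results, args.save_path)`, which keeps one file per run for later comparison.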
(5) Running the evaluation
Finally, run the command:
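# --pred_infos: location of the predictions
# --gt_infos: pkl of the KITTI validation set
# --save_path: where the results are saved
python evaluate.py \
    --pred_infos result.pkl \
    --gt_infos ./data/kitti/kitti_infos_val.pkl \
    --save_path ./distance/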
The output looks like this:
III. The modified code
I have put it in my open-source GitHub repository for reference: