轨迹预测后处理之非极大值抑制(NMS)

非极大值抑制是图像处理里面的一种算法(比如边缘检测会使用到)

轨迹预测这里借鉴了其思想,比如说对于某个场景中的某辆车,我们使用模型预测 64 条轨迹或者更多,以很好地捕获多模态性,同时每条轨迹对应一个置信度,所有轨迹置信度总和为 1。但最终输出时,我们一般仅输出 6 条轨迹,如果直接选择置信度最高的 6 条轨迹会存在问题,比如说这六条轨迹靠的很近,无法体现多模态性。

这里随便举个例子,比如说我有 10 条轨迹,其置信度分别为 [ 0.1 , 0.3 , 0.2 , 0.5 , 0.6 , 0.4 , 0.7 , 0.9 , 0.8 , 1.0 ] [0.1, 0.3,0.2,0.5,0.6,0.4,0.7,0.9,0.8,1.0] [0.1,0.3,0.2,0.5,0.6,0.4,0.7,0.9,0.8,1.0](应该加和等于1,为了方便说明这里忽略)。首席按将轨迹按照置信度从高到低排序,即 [ 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 1.0 ] [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0] [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0],假设每条轨迹有 80 个 waypoint 点,我们计算每两条轨迹之间最后一个点的距离,会产生一个 10*10 大小的距离矩阵。

现在我们依次按照置信度高低选取轨迹,比如第一次选择排名第一的轨迹,后面再选择轨迹时需要跟已经选择的所有判断距离是否大于某个阈值,如果小于该阈值,说明存在已选的轨迹与当前要被选择的轨迹很类似,则放弃选择该轨迹。

一图胜千言
在这里插入图片描述
从图中6条轨迹中选择出3条,如果按照置信度来选,应该选择0.8,0.5,0.4的轨迹,但由于0.5和0.4两条轨迹靠的太近(小于某个阈值)因此最终选择的轨迹为0.8,0.5,0.3三条轨迹。

下面是MTR++中算法的实现方式。

def batch_nms(pred_trajs, pred_scores, dist_thresh, num_ret_modes=6):
    """

    Args:
        pred_trajs (batch_size, num_modes, num_timestamps, 7)
        pred_scores (batch_size, num_modes):
        dist_thresh (float):
        num_ret_modes (int, optional): Defaults to 6.

    Returns:
        ret_trajs (batch_size, num_ret_modes, num_timestamps, 5)
        ret_scores (batch_size, num_ret_modes)
        ret_idxs (batch_size, num_ret_modes)
    """
    batch_size, num_modes, num_timestamps, num_feat_dim = pred_trajs.shape

    sorted_idxs = pred_scores.argsort(dim=-1, descending=True)
    bs_idxs_full = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_modes)
    sorted_pred_scores = pred_scores[bs_idxs_full, sorted_idxs]  # 对score从大到小排序
    sorted_pred_trajs = pred_trajs[bs_idxs_full, sorted_idxs]  # (batch_size, num_modes, num_timestamps, 7)
    sorted_pred_goals = sorted_pred_trajs[:, :, -1, :]  # (batch_size, num_modes, 7)  最后一个点

    dist = (sorted_pred_goals[:, :, None, 0:2] - sorted_pred_goals[:, None, :, 0:2]).norm(dim=-1)  # 64*64 的距离矩阵
    point_cover_mask = (dist < dist_thresh)

    point_val = sorted_pred_scores.clone()  # (batch_size, N)
    point_val_selected = torch.zeros_like(point_val)  # (batch_size, N)

    ret_idxs = sorted_idxs.new_zeros(batch_size, num_ret_modes).long()
    ret_trajs = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes, num_timestamps, num_feat_dim)
    ret_scores = sorted_pred_trajs.new_zeros(batch_size, num_ret_modes)
    bs_idxs = torch.arange(batch_size).type_as(ret_idxs)

    for k in range(num_ret_modes):
        cur_idx = point_val.argmax(dim=-1) # (batch_size)
        ret_idxs[:, k] = cur_idx

        new_cover_mask = point_cover_mask[bs_idxs, cur_idx]  # (batch_size, N)
        point_val = point_val * (~new_cover_mask).float()  # (batch_size, N)
        point_val_selected[bs_idxs, cur_idx] = -1
        point_val += point_val_selected

        ret_trajs[:, k] = sorted_pred_trajs[bs_idxs, cur_idx]
        ret_scores[:, k] = sorted_pred_scores[bs_idxs, cur_idx]

    bs_idxs = torch.arange(batch_size).type_as(sorted_idxs)[:, None].repeat(1, num_ret_modes)

    ret_idxs = sorted_idxs[bs_idxs, ret_idxs]
    return ret_trajs, ret_scores, ret_idxs
  • 21
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值