轨迹预测损失函数计算

本文介绍了两种轨迹预测方法:一是直接使用L1或L2损失,二是利用高斯混合模型(GMM)建模并设计了选择最接近真实轨迹的损失函数。文章详细描述了如何计算负对数似然损失和交叉熵损失,以及在MTR++框架中的实现过程。
摘要由CSDN通过智能技术生成

第一种

直接使用预测的轨迹和真实轨迹做L1或者L2损失,这种没什么好说的

第二种

将轨迹建模为高斯混合模型,参见 MultiPath: Multiple Probabilistic Anchor Trajectory Hypotheses for Behavior Prediction
公式如下:
p ( s ∣ x ) = ∑ k = 1 K π ( a k ∣ x ) ∏ t = 1 T ϕ ( s t ∣ a k , x ) p(\mathrm{s} |\mathrm{x} )=\sum_{k=1}^{K} \mathrm{\pi} (\mathrm{a}^k |\mathrm{x})\prod_{t=1}^{T}\phi \left ( s_t| \mathrm{a}^k, \mathrm{x} \right ) p(sx)=k=1Kπ(akx)t=1Tϕ(stak,x)
其中 x \mathrm{x} x 表示 encode 信息, K K K 表示预测的轨迹数目, T T T 表示每条轨迹的长度(waypoint 点的个数),每个预测的 waypoint 都建模为一个二维高斯模型,即
在这里插入图片描述
并且假设每个 waypoint 点之间独立同分布,则损失函数可以设计为:在这里插入图片描述
表示选择预测轨迹里面与GT最接近的轨迹产生 Loss,其他轨迹不产生 Loss,计算方式如下,比如说预测B个车辆的轨迹,每个车辆预测 64 条轨迹,每条轨迹包含80个waypoint点,则模型输出包含两部分,一部分是 B*64 个置信度,即64条轨迹对应的概率,另一部分是高斯模型的参数 B*64*80*5,每个 waypoint 的参数为 μ x , μ y , σ x , σ y , ρ \mu_x, \mu_y, \sigma_x, \sigma_y, \rho μx,μy,σx,σy,ρ。下面我从百度百科把二维高斯分布的表达式搬过来方便对照计算。
在这里插入图片描述
在计算时,我们不完全根据上面的损失函数表达式计算,而是采用

  • 置信度预测使用交叉熵损失函数计算
  • 二维高斯模型预测使用负对数似然计算
gt_traj = input_dict['gt_traj']  # (B, 80, 2)
pred_score, pred_trajs = model(input_decode)
# pred_score (B, 64)
# pred_trajs (B, 64, 80, 5)

# 选择与 gt 最接近的预测轨迹
dist = (gt_traj[:,None,:,:] - pred_trajs[:,:,:,:2).norm(dim=-1).sum(dim=-1)  # (B, 64, 80, 2)->(B, 64, 80)->(B, 64)
nearest_mode_idxs = dist.argmin(dim=-1)  # (B, 1)
nearest_mode_bs_idxs = torch.arange(pred_score.shape[0]).type_as(nearest_mode_idxs)
nearest_trajs = pred_trajs[nearest_mode_bs_idxs, nearest_mode_idxs]  # (B, 80, 5)

# ---------- 计算负对数似然损失 --------- #
res_trajs = gt_trajs - nearest_trajs[:, :, 0:2]  # (B, 80, 2)
dx = res_trajs[:, :, 0]  # x-\mu_x
dy = res_trajs[:, :, 1]  # y-\mu_y
log_std1 = torch.clip(nearest_trajs[:, :, 2], min=-1.609, max=5.0)  # log(\sigma_x)
log_std2 = torch.clip(nearest_trajs[:, :, 3], min=-1.609, max=5.0)  # log(\sigma_y)
std1 = torch.exp(log_std1)  # \sigma_x
std2 = torch.exp(log_std2)  # \sigma_y
rho = torch.clip(nearest_trajs[:, :, 4], min=-0.5, max=0.5)  # \rho

reg_gmm_log_coefficient = log_std1 + log_std2 + 0.5 * torch.log(1 - rho**2)  # 系数部分
reg_gmm_exp = (0.5*1/(1-rho**2)) * ((dx**2)/(std1**2) + (dy**2)/(std2**2) - 2*rho*dx*dy/(std1*std2))  # 指数部分

reg_loss = (reg_gmm_log_coefficient + reg_gmm_exp).sum(dim=-1)

# ---------- 计算交叉熵损失 --------- #
loss_cls = F.cross_entropy(input=pred_scores, target=nearest_mode_bs_idxs, reduction='none')

上面代码参考自 MTR++。

  • 22
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值