第一种
直接使用预测的轨迹和真实轨迹做L1或者L2损失,这种没什么好说的
第二种
将轨迹建模为高斯混合模型,参见 MultiPath: Multiple Probabilistic Anchor Trajectory Hypotheses for Behavior Prediction
公式如下:
p
(
s
∣
x
)
=
∑
k
=
1
K
π
(
a
k
∣
x
)
∏
t
=
1
T
ϕ
(
s
t
∣
a
k
,
x
)
p(\mathrm{s} |\mathrm{x} )=\sum_{k=1}^{K} \mathrm{\pi} (\mathrm{a}^k |\mathrm{x})\prod_{t=1}^{T}\phi \left ( s_t| \mathrm{a}^k, \mathrm{x} \right )
p(s∣x)=k=1∑Kπ(ak∣x)t=1∏Tϕ(st∣ak,x)
其中
x
\mathrm{x}
x 表示 encode 信息,
K
K
K 表示预测的轨迹数目,
T
T
T 表示每条轨迹的长度(waypoint 点的个数),每个预测的 waypoint 都建模为一个二维高斯模型,即
并且假设每个 waypoint 点之间独立同分布,则损失函数可以设计为:
表示选择预测轨迹里面与GT最接近的轨迹产生 Loss,其他轨迹不产生 Loss,计算方式如下,比如说预测B个车辆的轨迹,每个车辆预测 64 条轨迹,每条轨迹包含80个waypoint点,则模型输出包含两部分,一部分是 B*64
个置信度,即64条轨迹对应的概率,另一部分是高斯模型的参数 B*64*80*5
,每个 waypoint 的参数为
μ
x
,
μ
y
,
σ
x
,
σ
y
,
ρ
\mu_x, \mu_y, \sigma_x, \sigma_y, \rho
μx,μy,σx,σy,ρ。下面我从百度百科把二维高斯分布的表达式搬过来方便对照计算。
在计算时,我们不完全根据上面的损失函数表达式计算,而是采用
- 置信度预测使用交叉熵损失函数计算
- 二维高斯模型预测使用负对数似然计算
gt_traj = input_dict['gt_traj'] # (B, 80, 2)
pred_score, pred_trajs = model(input_decode)
# pred_score (B, 64)
# pred_trajs (B, 64, 80, 5)
# 选择与 gt 最接近的预测轨迹
dist = (gt_traj[:,None,:,:] - pred_trajs[:,:,:,:2).norm(dim=-1).sum(dim=-1) # (B, 64, 80, 2)->(B, 64, 80)->(B, 64)
nearest_mode_idxs = dist.argmin(dim=-1) # (B, 1)
nearest_mode_bs_idxs = torch.arange(pred_score.shape[0]).type_as(nearest_mode_idxs)
nearest_trajs = pred_trajs[nearest_mode_bs_idxs, nearest_mode_idxs] # (B, 80, 5)
# ---------- 计算负对数似然损失 --------- #
res_trajs = gt_trajs - nearest_trajs[:, :, 0:2] # (B, 80, 2)
dx = res_trajs[:, :, 0] # x-\mu_x
dy = res_trajs[:, :, 1] # y-\mu_y
log_std1 = torch.clip(nearest_trajs[:, :, 2], min=-1.609, max=5.0) # log(\sigma_x)
log_std2 = torch.clip(nearest_trajs[:, :, 3], min=-1.609, max=5.0) # log(\sigma_y)
std1 = torch.exp(log_std1) # \sigma_x
std2 = torch.exp(log_std2) # \sigma_y
rho = torch.clip(nearest_trajs[:, :, 4], min=-0.5, max=0.5) # \rho
reg_gmm_log_coefficient = log_std1 + log_std2 + 0.5 * torch.log(1 - rho**2) # 系数部分
reg_gmm_exp = (0.5*1/(1-rho**2)) * ((dx**2)/(std1**2) + (dy**2)/(std2**2) - 2*rho*dx*dy/(std1*std2)) # 指数部分
reg_loss = (reg_gmm_log_coefficient + reg_gmm_exp).sum(dim=-1)
# ---------- 计算交叉熵损失 --------- #
loss_cls = F.cross_entropy(input=pred_scores, target=nearest_mode_bs_idxs, reduction='none')
上面代码参考自 MTR++。