DenseTNT比起TNT的优点是不用前处理来获得anchor,也不需要后处理nms来获得轨迹。像TNT这种前处理获得到的anchor质量不能保证,nms这种后处理本质上是一种贪心算法,不一定是最优解的。
由于预测的目标只有一个,去做多目标真值有一定的困难,这篇文章还提出了一种离线的优化算法,能够得到一些pseudo-label。(个人认为这是这篇文章最大的亮点,离线和在线相结合的训练方法)
Sparse context encoding
对地图、障碍物的历史信息等进行encoding,参考了vectornet的方式,一开始用subgraph图网络再用attention。最后输出二维矩阵L,每一行代表了一个环境元素(地图或者agent)
Dense goal probability estimation
分为两部分,lane scoring和probability estimation。
通过lane scoring来减少求解空间,lane scoring还是有真值的,通过cross-entropy loss来计算
F是初始化设定的goal matrix,L是上一步的context encoding结果。先进行Cross-attention,得到结果A, 这个结果会传到后面的goal set predictor,作为heat map。同时会对F进行loss计算,这里会有goal的真值。
Goal set prediction
为了解决预测没有多个真值的问题,会使用offline的优化算法来优化,具体的,前边获得heatmap的方法和online是一样的,将goal set prediction这一块换成optimization algorithm。
然后再用optimization algorithm的结果,去指导离线部分。
由于这里没有多个真值,所以用之前得到的heatmap的goal来作为真值,同时也考虑heatmap的概率。然后使用迭代的方式来最小化distantce。这个迭代的方式有点像EM算法
在线部分参考了DETR这一篇目标检测的文章,将这个问题看作set prediction problem。利用offline得到的pseudo labels来作为真值来计算loss。由于之前已经经历了100次random perturbation,所以这个pseudo label还是非常接近于gt的。这个地方有两部分loss:一个是goal和pseudo label之间的loss,还有一种是多个预测头之间置信度的loss。
goal set prediction的模型结构如上文所示,用例attention + max_pooling + MLP的方式。attention和max_pooling这个结构大家公用,但是MLP 多个头是分开的,最后会预测出K的goal点和每个头的置信度。
这个地方置信度的真值标签v也很有意思,是在训练的时候现标注的,真值是预测出的轨迹里真值最近的那个head。
最后会取置信度最高的那个head输出的轨迹作为结果。
Trajectory completion
最后使用mlp来将goal延展出轨迹,并且和真值轨迹做一个loss
Learning
这个地方的learning也非常有意思,有确定性真值的先进行learning,没有确定性真值的goal prediction goal那部分再进行learning。