[文献阅读笔记]:SEPT: TOWARDS EFFICIENT SCENE REPRESENTATION LEARNING FOR MOTION PREDICTION
文章目录
文章地址:https://arxiv.org/pdf/2309.15289.pdf
1. 概述
SEPT所做的工作和Forecast-MAE总体很像,同样都是采用了自监督学习的方式对模型进行预训练,在此基础上进行fine-tune来预测轨迹输出,感兴趣的小伙伴可以看一下之前的这篇文章,
1.1 模型解决问题的方向
- 训练方法:自监督学习+监督学习,通过设计不同的三个子学习任务来对模型的结构进行预训练。
- 网络结构设计:两个不同的transformer模块来依次提取时序和空间交互特征,使用Learnable query与交互特征做cross-attention,输出最后的预测结果。
1.2 主要结论和贡献
-
作者设计了三种不同的自监督学习方法,分别是
- *Marked Trajectory Modeling (MTM),*随机mask掉轨迹中的部分点,通过预训练任务,旨在预训练时序特征提取模块,使模块可以更有效的建模时序特征。
- Masked Road Modeling (MRM),随机mask掉输入道路特征的部分点,旨在训练空间特征提取模块。
- Tail Prediction (TP),将轨迹分为头部和尾部,旨在通过前半部分的轨迹特征,预测后半部分的轨迹特征,算是简化版的轨迹预测。
SEPT与Forecast在SSL应用上的区别和联系:
相同点:
二者都是通过SSL,对模型的结构进行预训练,旨在获得模型在时序交互、空间交互上特征提取的能力。
不同点:
对于轨迹的mask:fmae,mask掉一部分比例的历史或者未来轨迹;sept则是mask掉历史轨迹中的部分轨迹点
对于轨迹的预测:fmae,是预测被mask掉的历史或者未来轨迹;sept则是预测mask掉的轨迹点
预训练任务:fmae,将车道与轨迹的重建,放到一个预训练任务中,通过一个任务,完成车道线与轨迹 mask部分的重建;
sept,则是分别预测轨迹、车道获取模型对时序以及空间建模的能力以及通过TP任务完成时空特征的交互。
2. 模型
2.1 模型架构
Input
轨迹:历史轨迹 [ A , T , D h ] [A,T,D_h] [A,T,Dh],其中A为周围障碍物的数量,T为时间序列的长度,特征包括轨迹点坐标,时间戳,类型和其他数据集属性
车道:车道线 [ R , D r ] [R,D_r] [R,Dr],特征包括车道起始点位置,车道长度,车道转向方向和其他数据集属性。另外SEPT还使用了purning model的方法,减少车道线的数量,降低计算量。
Projection
将不同输入维度的特征映射到固定的表征空间
R
D
\mathbb{R}^D
RD,
D
=
256
D=256
D=256是表征空间的维度。
P
r
o
j
e
c
t
(
x
)
=
m
a
x
(
W
x
+
b
,
0
)
\mathbf{Project}(x)=max(Wx+b,0)
Project(x)=max(Wx+b,0)
TempoNet
由3个堆叠的Transformer blocker组成,输入porject层的输出,维度为
[
A
,
T
,
D
]
[A,T,D]
[A,T,D],沿着T维度做self-attention,对T维度的输入做相对位置嵌入(T5,T5_coding)来编码时间序列的相对位置关系。
s
o
f
t
m
a
x
(
Q
×
K
T
+
p
o
s
i
t
i
o
n
b
i
a
s
)
×
V
softmax(Q\times K^T+position_{bias})\times V
softmax(Q×KT+positionbias)×V
输出经过max-pooling,获得时序特征,
[
A
,
D
]
[A,D]
[A,D]。
SpaNet
由2个堆叠的transformer blocker组成,输入轨迹和车道信息, [ A + S , D ] [A+S,D] [A+S,D],输出轨迹与车道之间的交互特征。
CrossAttender
由3个交叉注意力层组成,与以往轨迹与车道或者轨迹与轨迹之间的交叉注意力不同,本文作者使用一组可学习的query( [ N , D ] , N [N,D],N [N,D],N为预测的轨迹模式)去查询经过时空编码的特征,输出维度为 [ N , D ] [N,D] [N,D]的注意力特征,并经过两个MLP,输出轨迹和概率值。
2.2 模型Pre-Training
- MTM学习任务,主要去训练TempoNet学习时间序列的依赖。
- MRM学习任务,训练SpaNet学习对于道路时空特征的依赖。
- TP学习任务,通过头部轨迹特征预测尾部轨迹,整合时空特征依赖。
2.3 Fine-tune
训练完整的网络即可。
2.4 Loss function
回归损失+分类损失:
3. 实验
TP子学习任务的添加,由着更低的训练方差。
4. 思考和比较
为什么forecast-mae的指标要比sept稍逊一筹?
-
输入上:
fmae输入的是基于上一帧的位移;sept输入的是相对于预测目标的位置信息
-
结构上:
fmae使用NAT做时间序列编码;sept则是采用标准的transformer结构,另外在提取到时空特征后,sept还使用了可学习的query做交叉注意力特征提取,进而,提取到轨迹的高维表示,进而解码输出轨迹特征。
-
预训练方法上:
sept采用的是分部训练+联合训练的方法,分别训练各个模块;fmae则是只使用联合训练的方法。
如何将sept应用到多目标的预测?
1. 多目标表示:时间序列建模时,使用每个预测障碍物自身坐标系下的输入特征,空间序列建模时,进一步添加车辆与轨迹的时空位置嵌入,建模时空下的特征交互
2. 多目标预测:可学习的轨迹query从 [ N , D ] [N,D] [N,D]扩充到 [ A , N , D ] [A,N,D] [A,N,D]。
3. 训练策略:Loss function的更改
4. 评估指标:使用multi-agent预测的评估指标。