Hyper-STTN: Social Group-aware Spatial-Temporal Transformer Network for Human Trajectory Prediction
Hyper-STTN 中时空图概念体现在多个部分。空间-时间Transformer 网络利用类似时空图的思想,通过注意力机制在时空维度上对行人运动相关性进行抽象。空间 Transformer 在时间维度上运用图卷积网络,结合多头注意力层来表示空间注意力特征和空间关系特征;时间 Transformer 则利用多头注意力层捕捉每个个体的长期时间注意力依赖。同时,Hyper-STTN通过构建多尺度超图来进一步细化群体层面的时空依赖关系,超图中的顶点和超边不仅考虑了空间位置关系,还融入了时间因素,综合这些信息来推断行人轨迹。其网络结构如下图所示:

1.时空图
时空图的构建主要包括两个部分:空间图和时间图。这些图用于捕捉行人之间的成对交互和群体交互。
如下图所示,展示了构建时空图后,群组之间在时间维度和空间维度的交互影响。
1.1空间图(Spatial Graph)
- 节点(Nodes):每个节点代表一个行人,节点的特征是行人的位置信息(如坐标)。
- 边(Edges):边表示行人之间的空间关系,通常通过计算行人之间的距离或相似性来确定。如果两个行人的距离小于某个阈值,则在它们之间建立一条边。
- 空间交互(Spatial Interaction):通过空间图捕捉行人之间的即时交互,例如,行人之间的相对位置和运动方向。
1.2时间图(Temporal Graph)
- 节点(Nodes):每个节点代表一个行人在某个确定时间步的位置。
- 边(Edges):边表示同一行人在不同时间步之间的连接,用于捕捉行人的历史运动模式。
- 时间交互(Temporal Interaction):通过时间图捕捉行人的历史轨迹信息,例如,行人的速度和加速度。
2.超图
所谓超图(Hypergraph)是图的一种推广形式,用于表示更复杂的多对多关系。在传统图中,每条边只能连接两个节点(Node),而在超图中,超边可以连接多个节点,从而能够更灵活地表示高阶关系。
在该篇论文中被用来表示连接视野中人的联系。对比于普通的图关系来说超图可以表示多个被抽象为节点的人之间的关系如下图所示:
这种关系被类似与邻接矩阵的表示二位矩阵表示,经过超图连接关系生成后,可以作为神经网络的输入。
2.1超图(Hypergraph)的构建
- 使用不同尺度的KNN(K-Nearest Neighbors)算法在特征空间中构建多尺度超图。每个超图对应不同群体大小下的交互关系。
- 多尺度超图:通过构建多尺度超图来捕捉不同群体大小下的高阶交互关系,同时捕捉局部和全局的交互关系。在行人轨迹预测中,多尺度超图可以用于建模不同时间尺度上的行人交互,从而更准确地预测未来轨迹。
2.2超图谱卷积操作:
- 通过随机游走概率的超图谱卷积操作,将群体间的交互信息从点到边(point-to-edge)和从边到点(edge-to-point)进行聚合。
- 超图谱卷积:利用基于随机游走概率的超图谱卷积操作,将群体间的交互信息从点到边和从边到点进行聚合。超图谱卷积能够捕捉比传统图卷积更复杂的高阶关系,同时结合本文所提出的时空Transformer机制进一步增强了模型的表述能力。
3.时空Transformer
时空Transformer网络用于捕捉行人之间的成对交互,为抽象行人运动在时空维度的相关性,设计时空 Transformer 网络。输入数据先经位置编码(PE),再依次通过层归一化(LN)、多头空间注意力层和前馈网络(FFN),并借助残差连接稳定网络。
3.1空间Transformer(Spatial Transformer)
- 输入数据首先通过位置编码(Positional Encoding)层进行序列信息编码。
- 通过多头空间注意力层(Multi-head Spatial Attention Layer)和前馈网络(Feed-Forward Network, FFN)捕捉行人之间的空间关系。
3.1时间Transformer(Temporal Transformer)
- 通过多头注意力机制捕捉行人的历史运动模式。
- 使用掩码注意力机制(Mask Attention Mechanism)处理长度变化的序列数据。
4.多模态融合Transformer
Hyper-STTN通过超图卷积网络(HGNN)和时空Transformer网络分别捕捉群体交互和成对交互的特征。然而,这两种交互产生的特征是异构的,直接融合可能会导致信息混淆。因此,Hyper-STTN引入了多模态Transformer网络,用于对齐和融合这些异构特征。
综合来看,该机制的作用包括:
- 对齐异构特征:将群体交互和成对交互的特征对齐到同一特征空间,避免直接拼接操作可能忽略的跨模态关系。
- 捕捉跨模态注意力:通过跨模态注意力机制(cross-modal attention mechanism),突出不同模态特征之间的有意义关联。
- 增强特征融合:通过Transformer的多头注意力机制,模型能够更好地捕捉群体和成对交互之间的复杂关系。
特征融合的步骤如下:
- 1.特征提取:
群体交互特征:通过超图卷积网络(HGNN)提取,捕捉群体间的高阶交互。
成对交互特征:通过时空Transormer网络提取,捕捉个体之间的成对交互。 - 2.特征对齐:
使用多模态Transformer网络将两种特征对齐到同一特征空间。 - 3.跨模态注意力:
通过跨模态注意力机制,模型能够动态调整不同模态特征的权重,突出重要特征。 - 4.特征融合:
融合后的特征通过CVAE解码器生成未来的轨迹预测。
5.实验分析
实验在以下数据集上进行:
ETH-UCY数据集:包含ETH和UCY场景,是行人轨迹预测的常用基准数据集。
NBA数据集:包含篮球比赛中的球员轨迹。
SDD数据集:斯坦福无人机数据集,包含校园场景中的行人轨迹。
- 在ETH和UCY数据集上的表现:在ETH数据集上,Hyper-STTN在随机性条件下将ADE20(平均位移误差)从EqMotion的0.40降低到0.35,FDE20(最终位移误差)从0.61降低到0.59,分别提升了12.5%和3.2%。
- 在NBA数据集上,Hyper-STTN在多个场景中平均提升了14.7%、6.5%、3.4%和10.6%的ADE20性能,以及22.9%、14.7%、7.6%和10.1%的FDE20性能。
消融实验
消融实验验证了Hyper-STTN中各个模块的有效性:
1.超图神经网络(HGNN):
HGNN在多个数据集上的表现与GroupNet相当,证明了超图结构在群体交互建模中的有效性。
在行人数量较多的场景中,HGNN能够显著提升模型性能。
2.时空Transformer(STTN):
STTN通过引入掩码注意力和跨模态注意力机制,显著提升了轨迹预测的精度。
在行人较少的场景中,STTN的表现优于HGNN。
3.多模态Transformer:
多模态Transformer通过融合群体交互和成对交互的异构特征,进一步提升了模型性能。
直接拼接特征的消融模型(Hyper-STTN+MLP)在所有数据集上的表现均劣于完整的Hyper-STTN,证明了多模态Transformer在特征融合中的重要性。
总结
Hyper-STTN通过结合超图和时空Transformer,有效地捕捉了行人之间的成对交互和群体交互,显著提升了行人轨迹预测的性能。未来的工作可以进一步探索该模型在复杂场景中的应用,探索模型的轻量化设计和多模态数据融合,以提升模型的实时性和鲁棒性,并尝试将其扩展到实时检测中。