题目 | Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting |
---|---|
源码地址 | https://github.com/XDZhelheim/STAEformer |
关键词 | 交通预测,时空嵌入,Transformer |
发表会议 | CIKM2023 |
摘要
随着智能交通系统(ITS)的快速发展,准确的交通预测已成为一个关键挑战。主要瓶颈在于捕捉复杂的时空交通模式。近年来,提出了许多具有复杂架构的神经网络来解决这一问题。然而,网络架构的进步遇到了性能提升的瓶颈。在本研究中,我们提出了一种称为时空自适应嵌入的新组件,该组件在使用标准Transformer的情况下仍能产生卓越的结果。我们提出的时空自适应嵌入Transformer(STAEformer)在五个真实世界的交通预测数据集上达到了最先进的性能。进一步的实验表明,时空自适应嵌入在交通预测中发挥了关键作用,有效捕捉了交通时间序列中内在的时空关系和时间信息。
CCS概念
- 信息系统 → 时空系统;
- 计算方法 → 人工智能。
1 介绍
交通预测旨在基于历史观测预测道路网络中的未来交通时间序列。近年来,深度学习模型的成功尤为显著,主要归因于其能够捕捉交通系统中固有的时空依赖性。其中,时空图神经网络 (STGNNs) 和基于Transformer的模型因其出色的性能变得非常流行。研究人员花费了大量精力开发复杂的交通预测模型,例如新的图卷积、学习图结构、高效的注意力机制,以及其他方法。然而,网络架构的进展遇到了性能提升的瓶颈,促使人们从复杂的模型设计转向更有效的表征技术。
基于此,在本研究中,我们聚焦于输入嵌入,一种广泛使用的简单但强大的表征技术,许多研究人员往往在其有效性方面忽略了它。具体来说,它在输入上添加了一个嵌入层,为模型骨干提供了多种类型的嵌入。图1展示了之前模型中所采用嵌入的对比分析。STGNNs主要使用特征嵌入 E f E_f Ef,即一种转换将原始输入投射到隐藏空间。基于Transformer的模型需要额外的知识,如时间位置编码 E t p e E_{tpe} Etpe 和周期性(日常、每周、每月)嵌入 E p E_p Ep,这是由于注意力机制无法保留时间序列的位置信息。最近的模型,包括PDFOrmer、GMAN 和 STID,都应用了空间嵌入 E s E_s Es。值得注意的是,STID 是少数研究这些嵌入的研究之一。它采用了空间嵌入和时间周期嵌入,并结合了简单的多层感知机 (MLP),取得了显著的性能。
为了进一步增强表征的有效性,我们提出了一种新的时空自适应嵌入 E a E_a Ea,并将其与 E p E_p Ep 和 E f E_f Ef 一起应用于vanilla Transformer,如图1d所示。具体来说,原始输入通过嵌入层获取输入嵌入,输入嵌入被传递给时空Transformer层,然后经过回归层以做出预测。我们提出的模型命名为时空自适应嵌入Transformer (STAEformer),其架构更加简洁,但在性能上达到了最新的SOTA(state-of-the-art)。在我们的模型中, E a E_a Ea 通过有效捕捉交通时间序列中的内在时空关系和时间信息,发挥了关键作用。实验和对五个真实交通数据集的分析证明,我们提出的 E a E_a Ea 能够使vanilla Transformer在交通预测中达到SOTA水平。
2 问题定义
给定过去 T T T 个时间帧的交通时间序列 X t − T + 1 : t X_{t-T+1:t} Xt−T+1:t,交通预测旨在推断未来 T ′ T' T′ 个时间帧的交通数据,公式如下:
[ X t − T + 1 , … , X t ] → [ X t + 1 , … , X t + T ′ ] , [X_{t-T+1}, \dots, X_t] \rightarrow [X_{t+1}, \dots, X_{t+T'}], [Xt−T+1,…,Xt]→[Xt+1,…,Xt+T′],
其中每个帧 X i ∈ R N × d X_i \in \mathbb{R}^{N \times d} Xi∈RN×d, N N N 是空间节点的数量, d = 1 d=1 d=1 是交通量的输入维度。
3 方法
如图2所示,我们的模型由一个嵌入层、沿时间轴应用的vanilla transformer作为时间Transformer层、沿空间轴应用的空间Transformer层和回归层组成。
3.1 嵌入层
为了保持原始数据中的固有信息,我们利用全连接层来获得特征嵌入 E f ∈ R T × N × d f E_f \in \mathbb{R}^{T \times N \times d_f} Ef∈RT×N×df:
E f = F C ( X t − T + 1 : t ) E_f = FC(X_{t-T+1:t}) Ef=FC(Xt−T+1:t)
其中 d f d_f df 是特征嵌入的维度, F C ( ⋅ ) FC(\cdot) FC(⋅) 表示全连接层。
我们将可学习的星期几嵌入词典表示为 T w ∈ R N w × d f T_w \in \mathbb{R}^{N_w \times d_f} Tw∈RNw×df,时间戳嵌入词典表示为 T d ∈ R N d × d f T_d \in \mathbb{R}^{N_d \times d_f} Td∈RNd×df,其中 N w = 7 N_w=7 Nw=7 表示每周的天数, N d = 288 N_d=288 Nd=288 表示每天的时间戳数量。我们将 D t ∈ R T D^t \in \mathbb{R}^{T} Dt∈RT 作为星期几数据,时间戳数据为 t − T + 1 : t t-T+1:t t−T+1:t 中的交通时间序列,使用它们作为索引来提取对应的星期几嵌入 E w t ∈ R T × d f E_w^t \in \mathbb{R}^{T \times d_f} Ewt∈RT×df 和时间戳嵌入 E d t ∈ R T × d f E_d^t \in \mathbb{R}^{T \times d_f} Edt∈RT×df。通过拼接和广播,我们获得了交通时间序列的周期嵌入 E p ∈ R T × N × 2 d f E_p \in \mathbb{R}^{T \times N \times 2d_f} Ep∈RT×N×2df。
从直觉上讲,交通时间序列中的时间帧应与相邻时间帧更加相似。另一方面,不同传感器的时间序列可能具有不同的时间模式。因此,我们设计了一个时空自适应嵌入 E a ∈ R T × N × d a E_a \in \mathbb{R}^{T \times N \times d_a} Ea∈RT×N×da,以一种统一的方式捕捉复杂的时空关系。特别地, E a E_a Ea 在不同的交通时间序列之间共享。
通过拼接上述嵌入,我们得到了隐藏的时空表征 Z ∈ R T × N × d h Z \in \mathbb{R}^{T \times N \times d_h} Z∈RT×N×dh,公式如下:
Z = E f ∣ ∣ E p ∣ ∣ E a Z = E_f || E_p || E_a Z=Ef∣∣Ep∣∣Ea
其中隐藏维度 d h = 3 d f + d a d_h = 3d_f + d_a dh=3df+da。
3.2 Transformer与回归层
我们沿时间和空间轴应用vanilla transformer来捕捉复杂的交通关系。给定 T T T 帧和 N N N 空间节点的隐藏时空表征 Z ∈ R T × N × d h Z \in \mathbb{R}^{T \times N \times d_h} Z∈RT×N×dh,我们通过时间Transformer层获得查询、键和值矩阵:
Q ( t e ) = Z W Q ( t e ) , K ( t e ) = Z W K ( t e ) , V ( t e ) = Z W V ( t e ) Q^{(te)} = ZW_Q^{(te)}, \quad K^{(te)} = ZW_K^{(te)}, \quad V^{(te)} = ZW_V^{(te)} Q(te)=ZWQ(te),K(te)=ZWK(te),V(te)=ZWV(te)
其中 W Q ( t e ) , W K ( t e ) , W V ( t e ) ∈ R d h × d h W_Q^{(te)}, W_K^{(te)}, W_V^{(te)} \in \mathbb{R}^{d_h \times d_h} WQ(te),WK(te),WV(te)∈Rdh×dh 是可学习参数。然后我们计算自注意力得分:
A ( t e ) = Softmax ( Q ( t e ) K ( t e ) ⊤ d h ) A^{(te)} = \text{Softmax} \left( \frac{Q^{(te)} {K^{(te)}}^\top}{\sqrt{d_h}} \right) A(te)=Softmax(dhQ(te)K(te)⊤)
A ( t e ) ∈ R N × T × T A^{(te)} \in \mathbb{R}^{N \times T \times T} A(te)∈RN×T×T 捕捉不同空间节点中的时间关系。最终,我们得到时间Transformer层的输出 Z ( t e ) ∈ R T × N × d h Z^{(te)} \in \mathbb{R}^{T \times N \times d_h} Z(te)∈RT×N×dh,公式如下:
Z ( t e ) = A ( t e ) V ( t e ) Z^{(te)} = A^{(te)} V^{(te)} Z(te)=A(te)V(te)
类似地,空间Transformer层的输出为:
Z ( s p ) = SelfAttention ( Z ( t e ) ) Z^{(sp)} = \text{SelfAttention}(Z^{(te)}) Z(sp)=SelfAttention(Z(te))
其中自注意力机制遵循之前的公式。我们还应用了层归一化、残差连接和多头机制。
最后,我们利用时空Transformer层的输出 Z ′ ∈ R T ′ × N × d h Z' \in \mathbb{R}^{T' \times N \times d_h} Z′∈RT′×N×dh 来生成预测。回归层的公式如下:
Y ^ = F C ( Z ′ ) \hat{Y} = FC(Z') Y^=FC(Z′)
其中 Y ^ ∈ R T ′ × N × d \hat{Y} \in \mathbb{R}^{T' \times N \times d} Y^∈RT′×N×d 是预测, T ′ T' T′ 是预测的时间范围, d d d 是输出特征的维度,在我们的例子中 d = 1 d=1 d=1。因此,全连接层将 Z ′ Z' Z′ 中的维度从 T × d h T \times d_h T×dh 回归到 T ′ × ( d = 1 ) T' \times (d=1) T′×(d=1)。
4 实验
4.1 实验设置
设置。 我们的方法在五个交通预测基准上进行了验证(METR-LA、PEMS-BAY、PEMS04、PEMS07 和 PEMS08)。这五个数据集中的交通数据被聚合为5分钟间隔。因此,每小时有12个时间帧。更多详细信息如表2所示。
METR-LA 和 PEMS-BAY 被划分为训练集、验证集和测试集,比例为7:1:2。PEMS04、PEMS07 和 PEMS08 则按6:2:2的比例划分。事实上,我们模型的性能对超参数不敏感。更详细的信息为:特征嵌入维度 d f d_f df 为24,自适应嵌入维度 d a d_a da 为80。空间和时间Transformer的层数 L L L 为3,头数为4。我们将输入和预测长度设置为1小时,即 T = T ′ = 12 T=T'=12 T=T′=12。Adam被选为优化器,学习率从0.001开始衰减,批量大小为16。如果验证误差在连续30步内收敛,我们将采用提前终止机制。我们使用三个广泛用于交通预测任务的指标,即MAE、RMSE 和 MAPE。根据以往的工作,我们选择PEMS04、PEMS07 和 PEMS08数据集中预测12个时间范围的平均表现。为了评估METR-LA 和 PEMS-BAY数据集,我们比较了在3、6 和12步(分别为15、30 和60分钟)上的表现。
基线模型。 在本研究中,我们将提出的方法与领域中几种广泛使用的基线模型进行了比较。HI 是一个典型的传统模型。我们还考虑了STGNNs模型,例如GWNet、DCRNN、AGCRN、STGCN、GTS 和 MTGNN,它们使用了图1(a)中展示的嵌入。此外,我们考察了STNorm模型,它专注于分解交通时间序列。尽管存在基于Transformer的方法,如Informer、Pyraformer、FEDformer 和 Autoformer,但它们并未专门为短期交通预测设计。因此,我们选择了与我们任务相同的Transformer模型GMAN 和 PDFOrmer。文献 [42] 和 [14] 中的输入嵌入配置如图1(b)所示。此外,我们考虑了STID,该模型通过利用图1©中的输入嵌入增强了交通时间序列中的时空差异。
4.2 性能评估
如表1和表3所示,我们的方法在所有五个数据集的绝大多数指标上取得了更好的性能。STAEformer 在没有任何图结构建模的情况下大幅优于STGNNs。STNorm 和 STID 也获得了具有竞争力的结果,而基于Transformer的模型则能更好地捕捉复杂的时空关系。与PDFOrmer相比,令人鼓舞的结果表明STAEformer是一个更简单但更有效的解决方案。
4.3 消融实验
为了评估STAEformer中每个部分的有效性,我们进行了消融实验,使用了我们模型的四个变体,如下所示:
- 去掉 E a E_a Ea:移除时空自适应嵌入 E a E_a Ea。
- 去掉 E p E_p Ep:移除周期嵌入 E p E_p Ep,包括星期几和一天中的时间戳嵌入。
- 去掉时间Transformer层:移除时间Transformer层。
- 去掉时空Transformer层:移除时间和空间Transformer层。
表4揭示了各种嵌入对模型性能的重要性。 E p E_p Ep 可以捕捉到日常和每周的模式,而提出的 E a E_a Ea 对交通建模至关重要。此外,在去除空间或时间Transformer层时,我们观察到性能大幅下降,这表明我们的嵌入模型能够有效建模交通数据中的固有时空模式。因此,空间和时间层对于提取这些特征是必要的。
4.4 案例研究
与空间嵌入的比较。 为了验证时空自适应嵌入在捕捉输入帧 T T T 中隐含的时间顺序信息方面的有效性,我们在PEMS04和PEMS08数据集上进行了更多实验。我们随机打乱了沿时间轴 T T T 的原始输入。为了比较,我们将时空自适应嵌入 E a E_a Ea 替换为文献中使用的空间嵌入 E s E_s Es。如图3所示,当打乱输入时,我们的模型表现出更严重的性能下降。这意味着时空自适应嵌入 E a E_a Ea 使我们的模型对时间顺序更加敏感,而使用 E s E_s Es 的模型对顺序的敏感性较低。总结来说, E a E_a Ea 能够更好地建模输入中的时间顺序信息以及其他复杂的交通模式,这对于任务来说至关重要。
时空自适应嵌入的可视化。 图4进一步提供了我们提出的时空自适应嵌入 E a E_a Ea 在空间和时间轴上的可视化,以PEMS08数据集为例。在空间轴上,我们使用t-SNE获得图4a。结果显示,不同节点的嵌入自然地聚成簇,匹配了交通数据的空间特征。在时间轴上,我们计算了12个输入帧中各帧之间的相关系数,并绘制了图4b中的热力图。结果显示,每个时间帧与相邻帧的相关性较高,而对于更远的帧,相关性逐渐降低。结果表明,我们提出的 E a E_a Ea 能够正确建模时间序列中的时间信息。
5 结论
在本研究中,我们专注于一种用于交通时间序列预测的基础表征学习技术,即输入嵌入。我们提出了一种新的时空自适应嵌入,能够在vanilla transformers上实现交通预测基准测试中的最先进(SOTA)性能。进一步的研究表明,我们的模型可以有效捕捉交通中的内在时空依赖性。与设计复杂模型不同,我们的研究展示了一条解决交通预测挑战的有前途的方向。