【论文翻译】CIKM 2023 | STAEformer:时空自适应嵌入使基础Transformer在交通预测中达到最先进性能 (SOTA)

image-20241011205151729

题目Spatio-Temporal Adaptive Embedding Makes Vanilla Transformer SOTA for Traffic Forecasting
源码地址https://github.com/XDZhelheim/STAEformer
关键词交通预测,时空嵌入,Transformer
发表会议CIKM2023

摘要

随着智能交通系统(ITS)的快速发展,准确的交通预测已成为一个关键挑战。主要瓶颈在于捕捉复杂的时空交通模式。近年来,提出了许多具有复杂架构的神经网络来解决这一问题。然而,网络架构的进步遇到了性能提升的瓶颈。在本研究中,我们提出了一种称为时空自适应嵌入的新组件,该组件在使用标准Transformer的情况下仍能产生卓越的结果。我们提出的时空自适应嵌入Transformer(STAEformer)在五个真实世界的交通预测数据集上达到了最先进的性能。进一步的实验表明,时空自适应嵌入在交通预测中发挥了关键作用,有效捕捉了交通时间序列中内在的时空关系和时间信息。

CCS概念

  • 信息系统 → 时空系统;
  • 计算方法 → 人工智能。

1 介绍

交通预测旨在基于历史观测预测道路网络中的未来交通时间序列。近年来,深度学习模型的成功尤为显著,主要归因于其能够捕捉交通系统中固有的时空依赖性。其中,时空图神经网络 (STGNNs) 和基于Transformer的模型因其出色的性能变得非常流行。研究人员花费了大量精力开发复杂的交通预测模型,例如新的图卷积、学习图结构、高效的注意力机制,以及其他方法。然而,网络架构的进展遇到了性能提升的瓶颈,促使人们从复杂的模型设计转向更有效的表征技术。

image-20241011211043012

基于此,在本研究中,我们聚焦于输入嵌入,一种广泛使用的简单但强大的表征技术,许多研究人员往往在其有效性方面忽略了它。具体来说,它在输入上添加了一个嵌入层,为模型骨干提供了多种类型的嵌入。图1展示了之前模型中所采用嵌入的对比分析。STGNNs主要使用特征嵌入 E f E_f Ef,即一种转换将原始输入投射到隐藏空间。基于Transformer的模型需要额外的知识,如时间位置编码 E t p e E_{tpe} Etpe 和周期性(日常、每周、每月)嵌入 E p E_p Ep,这是由于注意力机制无法保留时间序列的位置信息。最近的模型,包括PDFOrmer、GMAN 和 STID,都应用了空间嵌入 E s E_s Es。值得注意的是,STID 是少数研究这些嵌入的研究之一。它采用了空间嵌入和时间周期嵌入,并结合了简单的多层感知机 (MLP),取得了显著的性能。

为了进一步增强表征的有效性,我们提出了一种新的时空自适应嵌入 E a E_a Ea,并将其与 E p E_p Ep E f E_f Ef 一起应用于vanilla Transformer,如图1d所示。具体来说,原始输入通过嵌入层获取输入嵌入,输入嵌入被传递给时空Transformer层,然后经过回归层以做出预测。我们提出的模型命名为时空自适应嵌入Transformer (STAEformer),其架构更加简洁,但在性能上达到了最新的SOTA(state-of-the-art)。在我们的模型中, E a E_a Ea 通过有效捕捉交通时间序列中的内在时空关系和时间信息,发挥了关键作用。实验和对五个真实交通数据集的分析证明,我们提出的 E a E_a Ea 能够使vanilla Transformer在交通预测中达到SOTA水平。

2 问题定义

给定过去 T T T 个时间帧的交通时间序列 X t − T + 1 : t X_{t-T+1:t} XtT+1:t,交通预测旨在推断未来 T ′ T' T 个时间帧的交通数据,公式如下:

[ X t − T + 1 , … , X t ] → [ X t + 1 , … , X t + T ′ ] , [X_{t-T+1}, \dots, X_t] \rightarrow [X_{t+1}, \dots, X_{t+T'}], [XtT+1,,Xt][Xt+1,,Xt+T]

其中每个帧 X i ∈ R N × d X_i \in \mathbb{R}^{N \times d} XiRN×d N N N 是空间节点的数量, d = 1 d=1 d=1 是交通量的输入维度。

3 方法

如图2所示,我们的模型由一个嵌入层、沿时间轴应用的vanilla transformer作为时间Transformer层、沿空间轴应用的空间Transformer层和回归层组成。

image-20241011211204260

3.1 嵌入层

为了保持原始数据中的固有信息,我们利用全连接层来获得特征嵌入 E f ∈ R T × N × d f E_f \in \mathbb{R}^{T \times N \times d_f} EfRT×N×df

E f = F C ( X t − T + 1 : t ) E_f = FC(X_{t-T+1:t}) Ef=FC(XtT+1:t)

其中 d f d_f df 是特征嵌入的维度, F C ( ⋅ ) FC(\cdot) FC() 表示全连接层。

我们将可学习的星期几嵌入词典表示为 T w ∈ R N w × d f T_w \in \mathbb{R}^{N_w \times d_f} TwRNw×df,时间戳嵌入词典表示为 T d ∈ R N d × d f T_d \in \mathbb{R}^{N_d \times d_f} TdRNd×df,其中 N w = 7 N_w=7 Nw=7 表示每周的天数, N d = 288 N_d=288 Nd=288 表示每天的时间戳数量。我们将 D t ∈ R T D^t \in \mathbb{R}^{T} DtRT 作为星期几数据,时间戳数据为 t − T + 1 : t t-T+1:t tT+1:t 中的交通时间序列,使用它们作为索引来提取对应的星期几嵌入 E w t ∈ R T × d f E_w^t \in \mathbb{R}^{T \times d_f} EwtRT×df 和时间戳嵌入 E d t ∈ R T × d f E_d^t \in \mathbb{R}^{T \times d_f} EdtRT×df。通过拼接和广播,我们获得了交通时间序列的周期嵌入 E p ∈ R T × N × 2 d f E_p \in \mathbb{R}^{T \times N \times 2d_f} EpRT×N×2df

从直觉上讲,交通时间序列中的时间帧应与相邻时间帧更加相似。另一方面,不同传感器的时间序列可能具有不同的时间模式。因此,我们设计了一个时空自适应嵌入 E a ∈ R T × N × d a E_a \in \mathbb{R}^{T \times N \times d_a} EaRT×N×da,以一种统一的方式捕捉复杂的时空关系。特别地, E a E_a Ea 在不同的交通时间序列之间共享。

通过拼接上述嵌入,我们得到了隐藏的时空表征 Z ∈ R T × N × d h Z \in \mathbb{R}^{T \times N \times d_h} ZRT×N×dh,公式如下:

Z = E f ∣ ∣ E p ∣ ∣ E a Z = E_f || E_p || E_a Z=Ef∣∣Ep∣∣Ea

其中隐藏维度 d h = 3 d f + d a d_h = 3d_f + d_a dh=3df+da

3.2 Transformer与回归层

我们沿时间和空间轴应用vanilla transformer来捕捉复杂的交通关系。给定 T T T 帧和 N N N 空间节点的隐藏时空表征 Z ∈ R T × N × d h Z \in \mathbb{R}^{T \times N \times d_h} ZRT×N×dh,我们通过时间Transformer层获得查询、键和值矩阵:

Q ( t e ) = Z W Q ( t e ) , K ( t e ) = Z W K ( t e ) , V ( t e ) = Z W V ( t e ) Q^{(te)} = ZW_Q^{(te)}, \quad K^{(te)} = ZW_K^{(te)}, \quad V^{(te)} = ZW_V^{(te)} Q(te)=ZWQ(te),K(te)=ZWK(te),V(te)=ZWV(te)

其中 W Q ( t e ) , W K ( t e ) , W V ( t e ) ∈ R d h × d h W_Q^{(te)}, W_K^{(te)}, W_V^{(te)} \in \mathbb{R}^{d_h \times d_h} WQ(te),WK(te),WV(te)Rdh×dh 是可学习参数。然后我们计算自注意力得分:

A ( t e ) = Softmax ( Q ( t e ) K ( t e ) ⊤ d h ) A^{(te)} = \text{Softmax} \left( \frac{Q^{(te)} {K^{(te)}}^\top}{\sqrt{d_h}} \right) A(te)=Softmax(dh Q(te)K(te))

A ( t e ) ∈ R N × T × T A^{(te)} \in \mathbb{R}^{N \times T \times T} A(te)RN×T×T 捕捉不同空间节点中的时间关系。最终,我们得到时间Transformer层的输出 Z ( t e ) ∈ R T × N × d h Z^{(te)} \in \mathbb{R}^{T \times N \times d_h} Z(te)RT×N×dh,公式如下:

Z ( t e ) = A ( t e ) V ( t e ) Z^{(te)} = A^{(te)} V^{(te)} Z(te)=A(te)V(te)

类似地,空间Transformer层的输出为:

Z ( s p ) = SelfAttention ( Z ( t e ) ) Z^{(sp)} = \text{SelfAttention}(Z^{(te)}) Z(sp)=SelfAttention(Z(te))

其中自注意力机制遵循之前的公式。我们还应用了层归一化、残差连接和多头机制。

最后,我们利用时空Transformer层的输出 Z ′ ∈ R T ′ × N × d h Z' \in \mathbb{R}^{T' \times N \times d_h} ZRT×N×dh 来生成预测。回归层的公式如下:

Y ^ = F C ( Z ′ ) \hat{Y} = FC(Z') Y^=FC(Z)

其中 Y ^ ∈ R T ′ × N × d \hat{Y} \in \mathbb{R}^{T' \times N \times d} Y^RT×N×d 是预测, T ′ T' T 是预测的时间范围, d d d 是输出特征的维度,在我们的例子中 d = 1 d=1 d=1。因此,全连接层将 Z ′ Z' Z 中的维度从 T × d h T \times d_h T×dh 回归到 T ′ × ( d = 1 ) T' \times (d=1) T×(d=1)

4 实验

4.1 实验设置

设置。 我们的方法在五个交通预测基准上进行了验证(METR-LA、PEMS-BAY、PEMS04、PEMS07 和 PEMS08)。这五个数据集中的交通数据被聚合为5分钟间隔。因此,每小时有12个时间帧。更多详细信息如表2所示。

image-20241011211320566

METR-LA 和 PEMS-BAY 被划分为训练集、验证集和测试集,比例为7:1:2。PEMS04、PEMS07 和 PEMS08 则按6:2:2的比例划分。事实上,我们模型的性能对超参数不敏感。更详细的信息为:特征嵌入维度 d f d_f df 为24,自适应嵌入维度 d a d_a da 为80。空间和时间Transformer的层数 L L L 为3,头数为4。我们将输入和预测长度设置为1小时,即 T = T ′ = 12 T=T'=12 T=T=12。Adam被选为优化器,学习率从0.001开始衰减,批量大小为16。如果验证误差在连续30步内收敛,我们将采用提前终止机制。我们使用三个广泛用于交通预测任务的指标,即MAE、RMSE 和 MAPE。根据以往的工作,我们选择PEMS04、PEMS07 和 PEMS08数据集中预测12个时间范围的平均表现。为了评估METR-LA 和 PEMS-BAY数据集,我们比较了在3、6 和12步(分别为15、30 和60分钟)上的表现。

基线模型。 在本研究中,我们将提出的方法与领域中几种广泛使用的基线模型进行了比较。HI 是一个典型的传统模型。我们还考虑了STGNNs模型,例如GWNet、DCRNN、AGCRN、STGCN、GTS 和 MTGNN,它们使用了图1(a)中展示的嵌入。此外,我们考察了STNorm模型,它专注于分解交通时间序列。尽管存在基于Transformer的方法,如Informer、Pyraformer、FEDformer 和 Autoformer,但它们并未专门为短期交通预测设计。因此,我们选择了与我们任务相同的Transformer模型GMAN 和 PDFOrmer。文献 [42] 和 [14] 中的输入嵌入配置如图1(b)所示。此外,我们考虑了STID,该模型通过利用图1©中的输入嵌入增强了交通时间序列中的时空差异。

4.2 性能评估

如表1和表3所示,我们的方法在所有五个数据集的绝大多数指标上取得了更好的性能。STAEformer 在没有任何图结构建模的情况下大幅优于STGNNs。STNorm 和 STID 也获得了具有竞争力的结果,而基于Transformer的模型则能更好地捕捉复杂的时空关系。与PDFOrmer相比,令人鼓舞的结果表明STAEformer是一个更简单但更有效的解决方案。

image-20241011211225897

image-20241011211238504

4.3 消融实验

为了评估STAEformer中每个部分的有效性,我们进行了消融实验,使用了我们模型的四个变体,如下所示:

  • 去掉 E a E_a Ea:移除时空自适应嵌入 E a E_a Ea
  • 去掉 E p E_p Ep:移除周期嵌入 E p E_p Ep,包括星期几和一天中的时间戳嵌入。
  • 去掉时间Transformer层:移除时间Transformer层。
  • 去掉时空Transformer层:移除时间和空间Transformer层。

image-20241011211302674

表4揭示了各种嵌入对模型性能的重要性。 E p E_p Ep 可以捕捉到日常和每周的模式,而提出的 E a E_a Ea 对交通建模至关重要。此外,在去除空间或时间Transformer层时,我们观察到性能大幅下降,这表明我们的嵌入模型能够有效建模交通数据中的固有时空模式。因此,空间和时间层对于提取这些特征是必要的。

4.4 案例研究

与空间嵌入的比较。 为了验证时空自适应嵌入在捕捉输入帧 T T T 中隐含的时间顺序信息方面的有效性,我们在PEMS04和PEMS08数据集上进行了更多实验。我们随机打乱了沿时间轴 T T T 的原始输入。为了比较,我们将时空自适应嵌入 E a E_a Ea 替换为文献中使用的空间嵌入 E s E_s Es。如图3所示,当打乱输入时,我们的模型表现出更严重的性能下降。这意味着时空自适应嵌入 E a E_a Ea 使我们的模型对时间顺序更加敏感,而使用 E s E_s Es 的模型对顺序的敏感性较低。总结来说, E a E_a Ea 能够更好地建模输入中的时间顺序信息以及其他复杂的交通模式,这对于任务来说至关重要。

image-20241011211339979

时空自适应嵌入的可视化。 图4进一步提供了我们提出的时空自适应嵌入 E a E_a Ea 在空间和时间轴上的可视化,以PEMS08数据集为例。在空间轴上,我们使用t-SNE获得图4a。结果显示,不同节点的嵌入自然地聚成簇,匹配了交通数据的空间特征。在时间轴上,我们计算了12个输入帧中各帧之间的相关系数,并绘制了图4b中的热力图。结果显示,每个时间帧与相邻帧的相关性较高,而对于更远的帧,相关性逐渐降低。结果表明,我们提出的 E a E_a Ea 能够正确建模时间序列中的时间信息。

image-20241011211402209

5 结论

在本研究中,我们专注于一种用于交通时间序列预测的基础表征学习技术,即输入嵌入。我们提出了一种新的时空自适应嵌入,能够在vanilla transformers上实现交通预测基准测试中的最先进(SOTA)性能。进一步的研究表明,我们的模型可以有效捕捉交通中的内在时空依赖性。与设计复杂模型不同,我们的研究展示了一条解决交通预测挑战的有前途的方向。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

holdoulu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值