论文题目 | PredFormer: Transformers Are Effective Spatial-Temporal Predictive Learners |
---|---|
论文链接 | https://openreview.net/forum?id=avNVrQ8D2v |
源码地址 | https://github.com/yyyujintang/PredFormer (Coming soon) |
关键词 | Transformer、时空预测、Gated Transformer Blocks(GTB) |
摘要
时空预测学习方法通常分为两类:基于循环的方式,这种方式在并行化和性能上面临挑战;以及无循环的方式,后者采用卷积神经网络(CNN)作为编码-解码架构。这些方法受益于强大的归纳偏置,但通常以可扩展性和泛化能力为代价。本文提出了PredFormer,一种纯基于Transformer的时空预测学习框架。受视觉Transformer(ViT)设计的启发,PredFormer利用精心设计的门控Transformer模块,经过对包括全注意力、因子分解以及交错时空注意力在内的3D注意力机制的全面分析。通过无循环的Transformer设计,PredFormer不仅简单且高效,在性能上大幅超越了以往的方法。对合成数据集和真实世界数据集的大量实验表明,PredFormer实现了最新的性能表现。在Moving MNIST数据集上,PredFormer相较于SimVP减少了51.3%的MSE。在TaxiBJ数据集上,该模型将MSE降低了33.1%,并将FPS从533提升到了2364。此外,在WeatherBench数据集上,PredFormer将MSE降低了11.1%,同时将FPS从196提高到了404。这些在精度和效率上的提升展示了PredFormer在实际应用中的潜力。源码和预训练模型将向公众开放。
PredFormer
1 介绍
时空预测学习通过基于过去的观测来预测未来的帧,从而学习空间和时间模式。这种能力在天气预报、交通流预测、降水预报和人体运动预测等应用中至关重要。
尽管各种时空预测学习方法取得了成功,但它们通常难以在计算成本和性能之间取得平衡。一方面,基于循环的强大方法依赖于自回归的RNN框架,然而这些方法在并行化和计算效率方面面临显著限制。另一方面,基于CNN的无循环方法虽然提高了效率,但由于局部感受野的限制,扩展性和泛化能力有限。这引发了一个更为基础的问题:我们是否可以开发一个框架,自主学习时空依赖关系,而不依赖归纳偏置?
图2:时空预测学习框架的主要类别。(a) 基于循环的框架 (b) 基于卷积神经网络编码-解码的无循环框架 ( c) 基于纯Transformer的无循环框架。
一种直观的解决方案是直接采用纯Transformer结构,因为它是RNN的高效替代方案,且比CNN更具可扩展性。Transformer在视觉任务中已表现出显著的成功。尽管已有的方法试图将Swin Transformer集成到RNN框架中,或将MetaFormer作为时间翻译器集成到无循环的CNN编码器-解码器框架中,但纯Transformer架构仍然主要处于探索阶段,特别是在捕捉统一框架中的时空关系方面存在挑战。尽管将空间和时间维度合并并应用全注意力的概念在理论上是可行的,但由于注意力与序列长度成平方的扩展,使得这种方法在计算上非常昂贵。为了减少复杂性,最近的几种方法通过因子分解或交错方式分别处理空间和时间关系。
在这项工作中,我们提出了PredFormer,一种纯基于Transformer的时空预测学习架构。PredFormer深入分解空间和时间Transformer,通过与门控线性单元(GLU)的自注意力结合,更有效地捕捉复杂的时空动态。除了保留空间优先和时间优先配置的全注意力编码器和因子分解编码器策略外,我们还引入了六种新颖的交错时空Transformer架构,共产生九种配置。这种探索是为了应对不同任务和数据集的不同空间和时间分辨率及依赖关系。通过全面的调查,推动了当前模型的边界,并为时空建模设定了有价值的基准。
特别地,PredFormer在三个基准数据集上取得了最先进的性能,包括合成的移动物体预测、交通流预测和天气预报。在不依赖复杂模型架构或专用损失函数的情况下,PredFormer以较大幅度超越了以往的方法。此外,我们的最优模型在性能上表现出色,参数更少,FLOP更低,推理速度更快,展示了其在实际应用中的巨大潜力。
主要贡献如下:
- 我们提出了PredFormer,一种纯基于门控Transformer的时空预测学习模型。通过消除CNN中固有的归纳偏置,PredFormer利用了Transformer的可扩展性和泛化能力,使其成为一个高度可适应的模型,显著提高了潜力和性能上限。
- 我们对时空Transformer因子分解进行了深入分析,探索了全注意力编码器和因子分解编码器,以及交错时空Transformer架构,共得出了九种PredFormer变体。这些变体针对不同任务和数据集的空间和时间分辨率,优化了性能。
- 据我们所知,PredFormer是第一个用于时空预测学习的纯Transformer模型。我们对从头开始在小数据集上训练ViT进行了全面研究,探索了正则化和位置编码技术。
- 大量实验表明,PredFormer表现出卓越的性能。与SimVP相比,PredFormer在Moving MNIST上将MSE降低了51.3%,在TaxiBJ上降低了33.1%,同时将FPS从533提高到2364,在WeatherBench上将MSE降低了11.1%,并将FPS从196提高到404。
2 相关工作
图1:(a) PredRNN、SimVP和PredFormer的性能表现;(b) 模型效率对比。图中位置越靠内的模型表示准确率和效率越高。
基于循环的时空预测学习
基于循环的时空预测模型的最新进展整合了CNN、ViT和Vision Mamba等结构到RNN中,采用多种策略来捕捉时空关系。ConvLSTM通过将卷积操作集成到LSTM框架中创新性地提出。PredNet利用自底向上的连接和自顶向下的连接来预测未来的视频帧。PredRNN引入了时空LSTM单元(ST-LSTM),通过传播隐藏状态有效地捕捉并记忆空间和时间表示。PredRNN++则通过引入梯度高速公路单元和因果LSTM来解决梯度消失问题,并自适应地捕捉时间依赖性。E3D-LSTM扩展了ST-LSTM的记忆能力,通过集成3D卷积。MIM模型进一步优化了ST-LSTM,重新设计了遗忘门,使用双循环单元并在隐藏状态之间利用差异信息。CrevNet使用基于CNN的可逆架构来有效地解码复杂的时空模式。PredRNNv2通过引入记忆解耦损失和课程学习策略来增强PredRNN。MAU设计了专门用于捕捉动态运动信息的运动感知单元。SwinLSTM则将Swin Transformer模块集成到LSTM架构中,而VMRNN扩展了该方法。
无循环的时空预测学习
最新的无循环模型,如SimVP,基于CNN编码器-解码器设计。TAU通过将时间注意力分离为静态帧内和动态帧间成分,并引入差异散度损失来监督帧间变化。OpenSTL集成了MetaFormer模型作为时间翻译器。此外,PhyDNet将物理原理引入CNN架构,而DMVFN则引入了动态多尺度体素流网络来增强视频预测性能。WAST提出了一种基于小波的损失函数。与之前的方法相比,PredFormer在其无循环的纯Transformer架构中通过利用全局感受野,提升了时空学习,超越了之前的模型表现。
视觉Transformer (ViT)
ViT在各种视觉任务中展示了出色的性能。在视频处理领域,TimeSformer研究了空间和时间自注意力的因子分解,并提出分别应用时间和空间注意力的方案以获得最佳准确率。ViViT探讨了因子分解编码器、自注意力和点积机制,得出的结论是优先应用空间注意力的因子分解编码器表现更好。TSViT发现优先应用时间注意力的因子分解编码器能够取得更好的结果。尽管取得了这些进展,目前大多数现有模型主要集中在视频分类领域,较少有研究将ViT应用于时空预测学习。PredFormer通过将自注意力与门控线性单元相结合,进一步深入分解时空Transformer,能够更强大地捕捉复杂的时空动态。
3 方法
为了系统地分析网络模型在时空预测学习中的Transformer结构,我们提出了PredFormer作为通用模型设计。
图3:(a) PredFormer模型框架概述。(b) 从空间视角和时间视角的序列分解。( c) 门控Transformer模块。(d) 门控线性单元。
3.1 纯基于Transformer的架构
Patch Embedding:按照ViT的设计,PredFormer将帧序列 X X X 切分为大小为 p p p 的非重叠patch,生成序列 N = ⌊ H p ⌋ × ⌊ W p ⌋ N = \left\lfloor \frac{H}{p} \right\rfloor \times \left\lfloor \frac{W}{p} \right\rfloor N=⌊pH⌋×⌊pW⌋ ,每个patch被扁平化为一维token。这些token被线性投影到隐藏维度 D D D ,并通过层归一化(LN)处理,生成张量 X ′ ∈ R B × T × N × D X' \in \mathbb{R}^{B \times T \times N \times D} X′∈RB×T×N×D。
Position Encoding:不同于典型的ViT方法,我们引入了二维时空位置编码(PE),该编码通过正弦函数生成并为每个patch分配绝对坐标。
PredFormer Encoder:这些一维token随后通过PredFormer编码器进行特征提取。PredFormer编码器由门控Transformer块以不同方式堆叠而成。
Patch Recovery:由于我们的编码器基于纯门控Transformer,不涉及卷积或分辨率减少,全球上下文在每一层建模。这允许其与简单的解码器配对,形成强大的预测模型。解码器将线性层作为解码器,将隐藏维度投影回去以恢复二维patch。
3.2 门控Transformer块
标准Transformer模型在多头注意力(MSA)和前馈网络(FFN)之间交替。每个头的注意力机制定义为:
Attention ( Q , K , V ) = Softmax ( Q K ⊤ d k ) V \text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V Attention(Q,K,V)=Softmax(dkQK⊤)V
其中,在自注意力中,查询 Q Q Q、键 K K K 和值 V V V 是输入 X X X 的线性投影,表示为 Q = X W q Q = XW_q Q=XWq, K = X W k K = XW_k K=XWk, V = X W v V = XW_v V=XWv,其中 X , Q , K , V ∈ R N × d X, Q, K, V \in \mathbb{R}^{N \times d} X,Q,K,V∈RN×d。FFN然后通过应用两个线性变换处理序列中的每个位置。
门控线性单元(GLU)经常用于替代简单的线性变换,涉及两个线性投影的逐元素乘积,其中一个投影通过sigmoid函数。不同的GLU变体通过用其他非线性函数替代sigmoid来控制信息流。例如,SwiGLU将sigmoid替换为Swish激活函数(SiLU),如公式(2)所示:
Swish β ( x ) = x σ ( β x ) \text{Swish}_{\beta}(x) = x\sigma(\beta x) Swishβ(x)=xσ(βx)
SwiGLU ( x , W , V , b , c , β ) = Swish β ( x W + b ) ⊗ ( x V + c ) \text{SwiGLU}(x, W, V, b, c, \beta) = \text{Swish}_{\beta}(xW + b) \otimes (xV + c) SwiGLU(x,W,V,b,c,β)=Swishβ(xW+b)⊗(xV+c)
SwiGLU在各种自然语言处理任务中表现优于多层感知器(MLP)。受SwiGLU成功的启发,我们的门控Transformer块(GTB)包含了MSA和基于SwiGLU的FFN,定义如下:
Y l = MSA ( LN ( Z l ) ) + Z l Y^l = \text{MSA}(\text{LN}(Z^l)) + Z^l Yl=MSA(LN(Zl))+Zl
Z l + 1 = SwiGLU ( LN ( Y l ) ) + Y l Z^{l+1} = \text{SwiGLU}(\text{LN}(Y^l)) + Y^l Zl+1=SwiGLU(LN(Yl))+Yl
3.3 PredFormer的变体
在预测学习中建模时空依赖性具有挑战性,因为空间和时间信息在不同任务和数据集上有所差异。这需要灵活、自适应的模型以适应不同的依赖性和尺度。为了解决这些问题,我们探讨了空间优先(Fac-S-T)和时间优先(Fac-T-S)配置的全注意力编码器和因子分解编码器。此外,我们基于PredFormer层引入了六种交错模型,使其能够在多个尺度之间进行动态交互。
PredFormer层是一个能够同时处理空间和时间信息的模块。基于该设计原则,我们提出了三种交错的时空范式:二元层、三元层和四元层,它们依次从空间视角和时间视角建模,如**图3(c)**所示。最终产生了六种不同的架构配置。图4中提供了这九种变体的详细说明。
图4: (a) 全注意力层和Binary-TS层的数据转换 (b) 全注意力编码器与因子化编码器 ( c) 交错编码器,包括Binary、Triplet和Quadrupled设计
对于全注意力层,给定输入 X ∈ R B × T × N × D X \in \mathbb{R}^{B \times T \times N \times D} X∈RB×T×N×D ,注意力计算跨越 T × N T \times N T×N 序列长度进行。我们通过几个堆叠的GTB计算注意力。
对于二元层,每个GTB块独立处理时间或空间序列,我们将其表示为二元-TS或二元-ST层。输入首先被重塑,并通过GTB的时间块 G T B t 1 GTB^1_t GTBt1 处理,其中在时间序列上应用注意力。然后将张量重塑回来以恢复时间顺序,随后使用另一个空间块 G T B s 2 GTB^2_s GTBs2 应用空间注意力,在时间维度上展平张量并进行处理。
对于三元和四元层,在二元结构之上堆叠了额外的块。三元-TST捕捉更多的时间依赖性,而三元-STS则更侧重于空间依赖性,两者使用相同数量的参数。四元层结合了两个不同顺序的二元层。我们省略了进一步的详细解释。
4 实验
我们对PredFormer和当前最先进的模型进行了广泛的评估。我们在合成和真实场景下进行实验,包括长期预测(移动物体轨迹预测和天气预报)和短期预测(交通流量预测)。数据集的统计信息在表1中展示。这些数据集具有不同的空间分辨率、时间帧和间隔,决定了它们不同的时空依赖关系。
实现细节。我们的方法使用PyTorch实现,实验在24GB的NVIDIA RTX 3090和24GB A5000 GPUs上进行,除非特别说明,所有实验都在单个GPU上运行。PredFormer使用AdamW优化器进行优化,结合L2损失、权重衰减值为1e-2,学习率在{5e-4, 1e-3}之间选择以获得最佳性能。Moving MNIST和TaxiBJ使用OneCycle调度器,WeatherBench则使用余弦调度器。为了防止TaxiBJ和WeatherBench的过拟合,使用Dropout和随机深度正则化方法。附录A.2中提供了更多超参数的详细信息。
对于不同的PredFormer变体,我们保持常量的GTB块数量以确保参数的可比性。在三元模型无法平均划分的情况下,我们使用最接近的GTB块数量。
评估指标。我们通过三个维度的指标评估模型性能:
- 像素级误差 使用均方误差(MSE)、平均绝对误差(MAE)和均方根误差(RMSE)进行衡量。
- 预测帧质量 通过结构相似性指数(SSIM)评估,MSE、MAE和RMSE的较低值以及较高的SSIM表明更好的预测效果。
- 计算效率 通过参数数量、浮点运算次数(FLOPs)和推理速度(FPS)进行衡量。我们使用NVIDIA A5000 GPU在秒帧率(FPS)下评估。这种多维度评估框架全面评估了模型的准确性、效率和可扩展性。
4.1 长期预测:Moving MNIST
Moving MNIST. Moving MNIST数据集作为评估序列预测模型的合成基准数据集。我们按照标准生成Moving MNIST序列,每个序列包含20帧,使用前10帧作为输入,后10帧作为预测目标。我们采用10000个序列用于训练,并为了公平比较,使用预生成的10000个序列作为验证集。
在Moving MNIST数据集中,作为最常用的基准数据集,我们采用两种训练设置来探索PredFormer框架的性能、收敛性、效率和变体。在第一个设置中,我们训练200个epochs来比较我们提出的9个模型与SimVP和TAU的表现,我们的定量结果展示在表2中。
表 2:Moving MNIST上的定量比较。每个模型观察10帧并预测接下来的10帧。我们训练模型200个epoch,并引用原论文的其他结果。
在第二个设置中,依据以往的研究,我们训练从200个epochs的运行中选出的最佳模型至2000个epochs,并在表3中报告最终结果。我们引用了每篇原始论文中的所有其他方法的结果,以便进行公平比较。
表 3:Moving MNIST上的定量比较。每个模型观察10帧并预测接下来的10帧。我们训练模型2000个epoch,并引用原论文的其他结果。
PredFormer能否比SimVP收敛更快? 当使用补丁大小为4时,我们的六个交错模型在仅仅200个epochs的训练中超过了SimVP在2000个epochs中的表现(MSE为23.8)。这表明PredFormer在有限的epochs中表现出更快的收敛速度,同时保持了优异的性能,突显了纯ViT框架相较于基于CNN方法的效率和鲁棒性。
ViT与CNN框架的上限比较。我们将补丁大小为4、在200个epochs中表现最佳的模型(Triplet-STS,MSE为20.7)扩展到2000个epochs,MSE显著降低至11.6,超越了SimVP 51.3%,并超越了TAU 41.4%。这些结果表明,我们的纯Transformer模型在很大程度上超越了以往的所有方法。CNN受限于归纳偏置,而它们难以匹配纯Transformer架构所具有的全局感受野优势,进一步强调了PredFormer在时空建模中的更高上限。
准确性与效率的权衡。在补丁大小为4的情况下,尽管参数量少于SimVP,PredFormer具有更高的FLOPs和较低的FPS。我们将补丁大小增加到8以平衡性能与效率,将计算减少到原来的四分之一。在这种配置下,FLOPs降至16.4G,低于SimVP的19.4G,并且与TAU的16.0G相当,FPS略低于SimVP。训练200个epochs时,PredFormer的MSE高于SimVP但低于TAU的200个epochs结果。扩展训练至2000个epochs后,SimVP的MSE从32.2降至23.8,TAU从24.6降至19.8,而我们的PredFormer显示出更大的改进,从26.0降至12.5。这再次证明了纯Transformer模型相较于CNN模型,即使在更大的补丁大小下仍然表现出更高的上限。具体来说,PredFormer相较于SimVP提高了47.5%,相较于TAU提高了36.9%,实现了令人印象深刻的准确性-效率权衡,并获得了显著的性能提升。
PredFormer的变体。在我们提出的变体中,出现了几种趋势:(1)补丁大小为4的200-epoch实验中:Fac-T-S模型表现优于全注意力模型,超过了Fac-S-T模型。交错模型明显优于因子分解模型和全注意力模型,MSE值范围在20到21之间。在这些模型中,Triplet-TST模型取得了最佳结果。(2)补丁大小为8的200-epoch实验中:交错模型持续超越了全注意力模型和因子分解模型,出现了一个明显的趋势:时间优先模型表现优于空间优先模型。值得注意的是,四元-TSST优于三元-TSTS,三元-STS优于三元-TST,而二元-TS略优于二元-ST。这表明,对于补丁大小为8的长期10$\rightarrow$10预测任务,时间依赖性起到了更关键的作用。(3)补丁大小为4的2000-epoch实验中:三元-STS略优于三元-TST,MSE达到11.6。这种差异可能归因于更长的空间序列和较小的补丁大小,在这种情况下,空间依赖性变得更为重要。(4)补丁大小为8的2000-epoch实验中:四元-TSST超越了三元-TST和二元-TS,MSE为12.5。
4.2 长期预测:WeatherBench
WeatherBench. 气候预测是时空预测学习中的关键挑战。WeatherBench数据集提供了一个全面的全球天气预报资源,涵盖了各种气候因素。在我们的实验中,我们使用WeatherBench-S的单变量设置,其中每个气候因子独立训练。我们专注于5.625°分辨率(32 × 64网格点)的温度预测。该模型基于2010年至2015年的数据进行训练,使用2016年的数据进行验证,并使用2017年至2018年的数据进行测试,所有数据间隔为1小时。在此设置中,我们输入前12帧并预测后续12帧。
定量评估。在表4中展示了我们在WeatherBench上的定量评估结果。
表 4:WeatherBench (T2m)上的定量比较。每个模型观察12帧并预测接下来的12帧。我们引用了其他来自原论文的结果。
我们得出了以下结论:
- 第一个结论与Moving MNIST一致,Fac-T-S模型优于全注意力模型,而后者又优于Fac-S-T模型。Fac-T-S模型在所有评估中表现最佳,MSE为1.100。
- 此外,六个交错模型显著优于所有其他基线,MSE值范围为1.108到1.149。值得注意的是,Triplet-TST模型获得了第二好的结果,MSE为1.108。
- Fac-T-S模型相较于SimVP在MSE上提高了11.1%,相较于TAU提高了5.9%。
- 有趣的是,最佳的Fac-T-S模型和次优的Triplet-TST模型都从时间模块开始。Triplet-TST模型强调时间依赖性而非空间依赖性,并以比Fac-T-S更少的参数实现了可比的结果。这表明,对于12$\rightarrow$12的长期预测任务,时间依赖性更为关键。
效率。我们的Fac-T-S模型表现出较强的性能并且需要较少的参数(从14.8M减少到5.3M)。虽然Fac-T-S模型的FLOPs与SimVP(8.6G)相当,但其FPS从196提高到404。此外,表现第二好的Binary-TST模型在效率和性能上都非常出色。这些结果表明我们的模型在现实世界的天气预报应用中具有巨大的潜力。
4.3 短期预测:TaxiBJ
TaxiBJ. TaxiBJ数据集包含来自北京的出租车GPS数据和气象数据。每个数据帧被可视化为一个32 × 32 × 2的热图,其中第三维度表示在指定区域内的交通流量进出。依据之前的工作,我们将最后四周的数据用于测试,利用前面的数据进行训练。我们的预测模型使用四个连续的观测值来预测随后的四帧。
定量评估。在表5中,我们展示了TaxiBJ的定量结果。
表 5:TaxiBJ上的定量比较。每个模型观察4帧并预测接下来的4帧。我们引用了原论文的其他结果。
我们得出了以下发现:
- 在全注意力和因子分解编码器模型中,Fac-T-S模型优于全注意力模型,而后者又优于Fac-S-T模型。
- 交错模型显著优于全注意力、Fac-S-T模型以及所有其他基线方法,MSE值范围在0.277至0.293之间。值得注意的是,Binary-ST和Triplet-STS的表现最佳。
- Triplet-STS模型相较于SimVP提高了33.1%,相较于TAU提高了19.5%。
- 有趣的是,两个表现最佳的模型都从空间模块开始,而Triplet-STS模型强调空间依赖性多于时间依赖性,以比Binary-ST更少的参数实现了可比的结果。这表明对于4$\rightarrow$4的短期预测任务,空间依赖性更为关键。
效率。我们的Triplet-STS模型实现了较强的性能,参数量更少,FLOPs更低,FPS更高。PredFormer将SimVP的参数量从13.8M减少到6.3M,将FLOPs从3.6G减少到1.6G,且将FPS从533提高到2364。这些结果表明模型在短期交通流预测中具有重要的现实应用潜力。
4.4 消融研究与讨论
我们对PredFormer模型设计进行了消融研究,并在表6中总结了结果。我们选择了在Moving MNIST上最佳的Triplet-TST-ps4 200-epoch模型,在TaxiBJ上最佳的Triplet-STS模型,以及在WeatherBench上最佳的Fac-T-S模型作为基线。
表6:关于门控线性单元和位置编码的消融研究。
门控线性单元(GLU)。用标准MLP替换SwiGLU会导致显著的性能下降。在Moving MNIST上,MSE从20.5上升至22.6;在TaxiBJ上,从0.277上升至0.306;在WeatherBench上,从1.100上升至1.171。这一性能下降表明了门控机制在建模复杂时空动态中的关键作用。
位置编码。同样,当我们将模型中的绝对位置编码替换为ViT中常用的可学习时空位置编码时,性能会恶化。在Moving MNIST上,MSE从20.5上升至22.2;在TaxiBJ上,从0.277上升至0.288;在WeatherBench上,从1.100上升至1.164。这些消融实验一致地展示了三个数据集中的类似趋势,强调了我们位置编码设计的鲁棒性。
模型正则化。纯粹的Transformer架构(如ViT)通常需要大型数据集进行有效训练,当应用于较小的数据集时,过拟合可能成为一个挑战。在我们的实验中,在WeatherBench和TaxiBJ数据集上都明显存在过拟合现象。我们在表7中实验了不同的正则化技术,发现dropout (DP) 和随机深度 (SD) 各自相比没有正则化时都能提高性能。然而,这两者结合使用能产生最佳效果。与传统的ViT实践不同,我们的任务中,均匀的drop path rate相比于随层次深度线性缩放的drop path rate表现显著更好。我们为所有九个变体采用了相同的正则化设置。
表7:关于Dropout和随机深度的消融研究。
可视化。图5和附录图6展示了我们PredFormer模型的预测结果及其在三个基准数据集上的相关预测误差。可视化结果表明,PredFormer模型大大减少了与TAU相比的预测误差,并提供了更准确的预测。我们在附录图7中展示了另一个案例,进一步证明了PredFormer相较于TAU的优异泛化能力。
图5:(a) Moving MNIST和(b) TaxiBJ的可视化。误差 = |预测值 - 目标值|。我们放大了误差以便更好地进行比较。
图6:WeatherBench全球温度预测的可视化。
关于PredFormer的讨论。尽管我们对时空分解进行了深入分析,但由于数据集的不同时空依赖性质,尚未确定最佳模型。在这项研究中,长期预测通常强调时间依赖性,而短期预测则更多依赖于空间依赖性。我们建议从Quadroplet-TSST模型开始,该模型适用于各种时空预测任务,并在所有数据集和配置中一致表现出色。
5 结论
在这篇论文中,我们介绍了PredFormer,这是一种为时空预测学习设计的无循环、无卷积模型。我们的深入分析扩展了对时空Transformer分解的理解,超越了现有的视频ViT框架。通过严格的实验,PredFormer展示了无与伦比的性能和效率,大幅超越了以前的模型。我们的结果揭示了几个关键见解:
- 交错时空Transformer架构建立了新的基准,跨多个数据集表现出色。
- 因子化的时间优先编码器显著优于全时空注意力编码器和因子化的空间优先配置。
- 实施dropout和统一的随机深度机制可以显著提升在过拟合数据集上的性能表现。
- 在所有基准中,绝对位置编码一直优于可学习的替代方案。
我们相信PredFormer不仅将为现实世界的应用建立一个强大的基准,而且还为未来基于纯Transformer的时空预测模型的创新铺平了道路。
A 附录
A.1 问题定义
时空预测学习旨在通过预测未来帧来学习空间和时间模式,基于过去的观测数据。给定帧序列 X t : T = { x t − T + 1 i } t − T + 1 t X^{t:T} = \{x^i_{t-T+1}\}^t_{t-T+1} Xt:T={xt−T+1i}t−T+1t,它封装了从时间 t t t 开始的最后 T T T 帧,目标是预测从时间 t + 1 t+1 t+1 开始的未来 T ′ T' T′ 帧,即 Y t + 1 : T ′ = { x t + 1 i } t + 1 t + 1 + T ′ Y^{t+1:T'} = \{x^i_{t+1}\}^{t+1+T'}_{t+1} Yt+1:T′={xt+1i}t+1t+1+T′。输入和输出序列表示为张量 X t : T ∈ R T × C × H × W X^{t:T} \in \mathbb{R}^{T \times C \times H \times W} Xt:T∈RT×C×H×W 和 Y t + 1 : T ′ ∈ R T ′ × C × H × W Y^{t+1:T'} \in \mathbb{R}^{T' \times C \times H \times W} Yt+1:T′∈RT′×C×H×W,其中 C C C、 H H H 和 W W W 分别表示通道、高度和宽度。 T T T 和 T ′ T' T′ 分别为输入和输出帧的数量。为简洁起见,以下部分我们用 X X X 和 Y Y Y 分别表示 X t : T X^{t:T} Xt:T 和 Y t + 1 : T ′ Y^{t+1:T'} Yt+1:T′。
通常情况下,我们采用配备了可学习参数 F Θ \mathcal{F}_{\Theta} FΘ 的深度模型来进行未来帧预测。通过求解以下优化问题可以获得最优参数集 Θ ∗ \Theta^* Θ∗:
Θ ∗ = arg min Θ L ( F Θ ( X ) , Y ) \Theta^* = \arg \min_{\Theta} \mathcal{L}(\mathcal{F}_{\Theta}(X), Y) Θ∗=argΘminL(FΘ(X),Y)
其中, L \mathcal{L} L 是测量预测结果与真实值之间差异的损失函数。
A.2 实验设置
在Moving MNIST实验中,200个epoch使用patch大小为4的情况下,由于内存限制,对于完整的注意力模型我们使用批量大小为2,而其他变体则使用批量大小为8。在相同patch大小下进行2000个epoch的实验时,我们将批量大小增加到16,使用单个48GB的A6000 GPU。在使用patch大小为8的实验中,我们在所有运行中保持批量大小为16,使用24GB的GPU。在Moving MNIST实验中,我们为所有PredFormer变体使用了24个GTB模块,具体包括6个Quadroplet-TSST层、8个Triplet-TST层和12个Binary-TS层。
对于TaxiBJ和WeatherBench数据集,我们为Triplet变体使用了6个GTB模块,而为其他变体使用了8个GTB模块。
A.3 更多可视化
图6展示了WeatherBench上的可视化。随着帧数的增加,TAU的误差比我们的误差增加得更显著。这展示了我们PredFormer模型在长期预测中的优势。
图7(a)和(b)显示了同一时间步的流入和流出。在这种情况下,第四帧显示的交通流量明显低于之前的帧。由于卷积神经网络(CNNs)的归纳偏差,TAU继续预测高流量水平。相反,我们的PredFormer通过准确捕捉这一突变展示了出色的泛化能力。这种能力突显了PredFormer在处理极端情况下的潜力,特别是在交通流预测和天气预测等应用中非常有价值。
图7:TaxiBJ流入和流出的可视化。我们放大了误差以便更好地进行比较。