提出了VMRNN单元,这是一种结合了Vision Mamba模块和LSTM优点的新型循环单元。经过广泛的评估,我们的方法在多个基准测试中表现出色,同时保持了较小的模型规模。
CNN的限制:处理输入数据时,单个卷积层只能捕捉到局部的、有限范围的特征信息;卷积核的大小决定了每次卷积操作能够“看到”的输入区域,即感受野
解决办法:增加网络深度、池化层、跳跃连接----->训练难度、计算成本、内存需求增加
ViT的限制:自注意力机制的计算复杂度随着输入序列的长度呈二次增长;
自注意力机制:让每个输入元素与其他所有元素进行交互,计算它们之间的相关性,捕捉全局信息,计算成本大幅上升
Mamba架构:长序列建模能力表现出色,线性复杂度,
VMRNN单元:结合了Vision Mamba模块和LSTM优点的新型循环单元
构建了一个以VMRNN单元为中心的网络,有效解决时空预测任务
1. Introduce
时空预测:降水预测、自动驾驶、交通流量预测、人类运动预测、表示学习
问题:时空数据的复杂物理交互和不可预测特性给仅依赖数据驱动的深度学习方法带来了重大障碍,难以精准预测
时空预测学习的核心:深入探索物理世界中固有的空间相关性、时间进展,在视频层面上从过去帧预测未来帧
现有方法:
CNN或ViT与RNN结合的递归方法:ConvLSTM、PredRNN、MIM、E3DLSTM...
非递归方法:SimVP、TAU
现有方法的问题:
CNN受限于局部感知野,限制了其从远处图像区域吸收信息的能力;
ViTs表现更好,归因于全局感知野和通过注意力机制实现的动态权重,但平方复杂度
好一点的新模型:
结构状态空间模型SSM:建模广泛序列方面表现出高效性和有效性
Mamba:解决各类任务中的长程依赖,引入输入序列的选择性并采用扫描算法,线性
文章提出的模型:VMRNN递归单元:将视觉Mamba模块与LSTM结合,以有效提取时空表示
开发了一种以VMRNN为中心的模型,专门用于识别对时空预测至关重要的空间和时间动态
模型在图像级处理每一帧,将其分割为补丁,并将这些补丁在传入到补丁嵌入层进行初步处理之前将其扁平化
模型继承了基于递归方法的特性
模型的VMRNN层利用这些变换后的补丁和之前的状态,捕捉时空表示以进行下一个预测
启发思路:Mamba强大的序列建模能力、线性复杂度
2.Relate Work
基于卷积CNN的架构:早期CNN与RNN相结合
ConvLSTM:引入卷积操作代替全连接操作,出金时空相互依赖关系的学习
PredRNN及其时空单元LSTM:能够同时水平和垂直传播隐藏状态来并行处理时空数据
PredRNN++:引入梯度高速公路单元,缓解前代模型遇到的梯度消失问题
E3D-LSTM:引入3D卷积增强ST-LSTM的记忆容量
MIM:双重递归单元重新设计了ST-LSTM的遗忘门,以更好地应对预测中非平稳信息的学习
CrevNet:使用基于CNN的可逆架构来解码复杂的时空模式
PhyDNet:将物理原理嵌入到CNN框架中,以提高预测质量
局部化操作限制了捕捉时空依赖性的能力
基于Transformer的架构:最初在NLP邻域,随后在CV领域也开始探索
Vision Transformer (ViT):首次直接将Transformer架构应用于图像分类
Swin Transformer:创新的移动窗口策略和分层结构,在图像分类、语义分割、目标检测等多个任务取得了显著成果
SwinLSTM:将Swin Transformer与LSTM结合,建立了新的稳健的时空预测基准
注意力机制在图像大小上的二次复杂性,增大了计算复杂度
状态空间模型:
SSM:最近被引入到深度学习中,作为状态空间转换模型
LSSL:受控制系统中连续状态空间模型的启发,结合HiPPO初始化,展现了处理长距离依赖的潜力
状态表示带来的巨大计算和内存需求,LSSL在实际应用中不可行
S4:将参数归一化为对角结构,解决上述问题
Mamba:提出选择性扫描空间状态序列模型模块,有选择性的处理输入序列并采用扫描算法
3.Method
整体架构:
VMRNN-B:仅有一个VMRNN单元的基础模型
VMRNN 层处理嵌入的图像块、前一时间步的隐藏状态 和细胞状态
,以生成当前的隐藏状态
和细胞状态
会复制生成两个版本:一个传递到重构层,另一个与 一起服务于下一个时间步的 VMRNN 层
对于 VMRNN-B,架构主要依赖于 VMRNN 层的堆叠
VMRNN-D:包含多个VMRNN单元的深度模型,
图块合并(Patch Merging)层:下采样,有效减少数据的空间维度,降低计算复杂度并捕获更抽象的全局特征
图块扩展(Patch Expanding)层:上采样,增加空间维度,促进细节的恢复,在重构阶段精确定位特征
重构层将VMRNN层的隐藏状态 恢复到输入大小,生成下一个时间步的预测帧
优势:
下采样简化了输入表示,是模型能以较少的计算开销处理更高级别的特征,利于理解数据的复杂模式和关系
上采样确保空间细节不会丢失
关键:
下采样与上采样之间的平衡,实现高质量预测,尤其在需要细粒度理解和视觉数据生成的任务中
VMRNN模块:图2(a)
VMRNN 模块移除了 ConvLSTM 中的所有权重 和偏置 ,得到方程:
通过更新细胞状态 捕捉长短期的时间依赖性,并从水平角度更新隐藏状态
VMRNN的关键公式:
VSB为VSS块,LP为线性投影
VSS块:图2(b)
输入---初始线性嵌入层---分为两个不同的信息流
第一个流---3*3的深度卷积层,该层带有Silu激活函数---SS3D模块---归一化层---与经过Silu激活函数的第一个流合并
VSS块复杂度比ViT更低,可以在相同模型深度约束下包含更多的块
VSS 块首先将线性投影恢复为图像形状,然后 VSS 块通过采用 2D 选择性扫描(SS2D)解决了与 2D 图像数据相关的挑战
给定输入特征 ,SS2D 的输出特征
可以表示为:
其中表示四个不同的扫描方向(图2(c))。expand(·) 和 merge(·) 对应于扫描扩展和扫描合并操作
VSS块的核心操作符:,使得1D数组中的每个元素能够压缩的隐藏状态与之前烧苗的样本进行交互
4.Experiments
损失函数:MSE
模型超参数:
评估指标:MSE、MAE、PSNR(峰值信噪比)、SSIM(结构相似性指数测量)
计算需求:参数数量、浮点运算数、每秒帧数
数据集:
Moving MNIST:评估序列预测模型的基准合成数据集
KTH:人类动作
TaxiBJ:北京出租车的GPS数据和气象数据
消融研究:
数据集:TaxiBJ
卷积层:VSS 块中的卷积层用于解码 VMRNN 单元提取的时空特征,DW 卷积(DW Conv)显著优于其他解码方法
图像块大小:图像块大小会影响输入序列的长度
VSS块数量:评估了 VSS 块数量对建模全局空间信息的影响,当 VSS 块数量为12时,模型在 MSE 和 SSIM 上表现最佳