论文阅读:SimVP: Simpler yet Better Video Prediction

论文地址:arxiv

摘要

作者认为,现有的CNN,RNN,Transformer 之类的视频预测领域的模型都过于复杂了,作者想要找到一个简单的方式,同时可以达到与之相当的效果。

作者提出了 SimVP,这是一个简单的视频预测模型,完全基于 CNN 构建,通过均方误差(MSE)损失函数以端到端的方式进行训练。在不引入任何额外技巧与复杂策略的情况下,就可以实现最先进的性能。

正文

深度视频预测模型当前主要有 4 类,如图所示:

分别是:

  1. RNN-RNN-RNN
  2. CNN-RNN-CNN
  3. CNN-ViT-CNN
  4. CNN-CNN-CNN

在纯 CNN 基础模型方面,要提高准确度,通常要使用各种技术,但是作者探索出了一个简单模型的新高度。

问题描述

给定一个在时间 t t t 的包含过去 T T T 帧的视频序列 X t , T = { x i } t − T + 1 t X_{t,T}=\{x_{i}\}^t_{t-T+1} Xt,T={xi}tT+1t,而目标是在时间 t t t 预测未来的序列 Y t , T ′ = { x i } t t + T ′ Y_{t,T'} = \{x_{i}\}^{t+T'}_t Yt,T={xi}tt+T。该序列包含接下来的 T ′ T' T 帧,其中 x i x_{i} xi 是一个具有通道数 C C C,高度 H H H 和宽度 W W W 的图像。形式上,预测模型是一个映射 F Θ : X t , T − > Y t , T ′ F_\Theta:X_{t,T}->Y_{t,T'} FΘ:Xt,T>Yt,T,其中的可学习参数 Θ \Theta Θ 通过以下公式优化:

Θ ∗ = arg ⁡ min ⁡ Θ L ( F Θ ( X t , T ) , Y t , T ′ ) \Theta ^* = \arg \min _{\Theta } \mathcal {L}(\mathcal {F}_{\Theta }(\boldsymbol {X}_{t, T}), \boldsymbol {Y}_{t, T'}) Θ=argΘminL(FΘ(Xt,T),Yt,T)
L L L 可以是各种损失函数。

模型架构

SimVP 由一个编码器,一个翻译器,一个解码器组成。

  • 编码器用于提取空间特征
  • 翻译器学习时间演变
  • 解码器则整合时间信息以预测未来帧

编码器

编码器堆叠了 N s N_s Ns 个 ConvNormReLU 块(Conv2d+LayerNorm+LeakyReLU)来提取空间特征,即在 (H,W)上进行 C 通道的卷积。隐藏特征表示为:

z i = σ ( L a y e r N o r m ( C o n v 2 d ( z i − 1 ) ) ) , 1 ≤ i ≤ N s z_{i} = \sigma (\mathrm {LayerNorm} (\mathrm {Conv2d}(z_{i-1}))), 1 \leq i \leq N_s zi=σ(LayerNorm(Conv2d(zi1))),1iNs

其中输入 z i − 1 z_{i-1} zi1 ​和输出 z i z_i zi 的形状分别为 ( T , C , H , W ) (T, C, H, W) (T,C,H,W) ( T , C , H ^ , W ^ ) (T, C, \hat{H}, \hat{W}) (T,C,H^,W^)

翻译器

翻译器使用 N t N_t Nt 个 Inception 模块来学习时间演变,即在 ( H , W ) (H, W) (H,W) 上进行 T ∗ C T*C TC 通道的卷积。

Inception 模块由一个 1*1 大小的 Conv2d 后接并行的 GroupConv2d 操作符完成。隐藏特征表示为:
z j = I n c e p t i o n ( z j − 1 ) , N s < j ≤ N s + N t z_{j} = \mathrm {Inception}( z_{j-1} ), N_s < j \leq N_s+N_t zj=Inception(zj1),Ns<jNs+Nt
其中输入 z j − 1 z_{j-1} zj1 和输出 z j z_j zj 的形状分别为 ( T ∗ C , H , W ) (T*C, H, W) (TC,H,W) ( T ^ ∗ C ^ , H , W ) (\hat{T}*\hat{C}, H, W) (T^C^,H,W)

解码器

解码器使用 N s N_s NsunConvNormReLU 块(ConvTranspose2d+GroupNorm+LeakyReLU)来重建真实帧,在(H, W)上进行 C 通道的卷积。隐藏特征表示为

z k = σ ( G r o u p N o r m ( u n C o n v 2 d ( z k − 1 ) ) ) , N s + N t < k ≤ 2 N s + N t z_{k} = \sigma (\mathrm {GroupNorm} (\mathrm {unConv2d}(z_{k-1}))),\\ N_s+N_t < k \leq 2N_s + N_t zk=σ(GroupNorm(unConv2d(zk1))),Ns+Nt<k2Ns+Nt

其中,输入 z k − 1 z_{k-1} zk1 与输出 z k z_k zk 的形状分别为 ( T , C ^ , H ^ , W ^ ) (T,\hat{C},\hat{H}, \hat{W}) (T,C^,H^,W^) ( T , C , H , W ) (T,C,H,W) (T,C,H,W)。使用 ConvTransposed2d 作为 unConv2d 操作符。

模型评估

使用均方误差(MSE)、平均绝对误差(MAE)、结构相似性指数(SSIM)和峰值信噪比(PSNR)来评估预测质量。

在五个数据集上进行实验,从而来进行评估,如下所示:

性能评估

可以看到,SimVP、PhyDNet和CrevNet显著优于先前的方法,MSE降低达到42%。然而,SimVP比PhyDNet和CrevNet简单得多,没有使用RNN、LSTM或复杂模块。

训练时间

由上可知,SimVP 的训练过程比其他方法快得多,所以 SimVP 可以更容易地使用与扩展。

翻译器的使用

使用 RNN 和 Transformer 替换了 CNN ,再进行测试。使用了不同模型中的翻译器来,测试后得到以下结果:

可以得出结论:

  1. CNN和RNN在有限的计算成本下实现了最先进的性能。
  2. 如果模型容量足够,RNN在长期内收敛速度更快。
  3. CNN训练更稳健,在大学习率下不会剧烈波动。
  4. 在类似的资源消耗下,Transformer在我们的SimVP框架中没有优势。

评判能否到 SOTA 水平

SimVP可以在轻量级其次上达到 SOTA 结果。此外,与 PhyDNet 相比,SimVP 的训练时间更短。

可以看到在不同的数据集上有良好的泛化能力。

可以看到,SimVP 在灵活预测长度的情况下扩展良好。SimVP 达到了最新的性能。

消融实验

哪个架构的设计对于性能有关键的作用?

由上图 1-4 可知:空间UNet、时间UNet、分组卷积和分组归一化都能带来性能提升,其重要性排序为:分组卷积 > 分组归一化 ≈ 空间UNet ≈ 时间UNet。

卷积核对性能的影响

由上图 5-8 可知,随着核大小的增加,可以看到显著的性能提升。通过将模型 8 的隐藏维度加倍构建于模型 9,这种提高可以进一步增强。

编码器,转换器,解码器的角色

  • 转换器主要关注预测物体的位置和内容。
  • 解码器负责优化前景物体的形状。
  • 编码器可以通过空间UNet连接消除背景误差。
### 解决 Python 程序中的 `TypeError: cannot unpack non-iterable NoneType object` 错误 当遇到 `TypeError: cannot unpack non-iterable NoneType object` 这样的错误时,通常意味着尝试解包一个返回值为 `None` 的对象。这可能发生在函数调用的结果被期望是一个可迭代的对象(如列表或元组),但实际上该函数返回了 `None`。 对于 SimVP 中的 `main.py` 文件,在调用 `load_data` 函数时发生此异常,可以考虑以下几个方面来排查并解决问题: #### 1. 检查数据加载逻辑 确保 `load_data` 函数内部正确处理了所有情况下的输入参数,并总是返回预期类型的输出。如果某些条件下无法获取到有效数据,则应设置默认返回值而不是直接返回 `None` 或者抛出更具体的异常[^1]。 ```python def load_data(param): result = some_operation_that_may_fail() if not isinstance(result, (list, tuple)): raise ValueError("Data loading failed or returned unexpected type.") return result ``` #### 2. 验证外部依赖项 确认任何用于读取文件、网络请求或其他资源的操作都成功完成。特别是如果有涉及到数据库查询或者其他 I/O 操作的地方,要保证这些操作能够正常工作并且不会因为连接失败等原因而提前终止导致返回 `None`。 #### 3. 增加调试信息 为了更好地理解问题所在位置以及为何会得到 `None` 类型的数据,可以在适当的位置加入日志记录语句或者断点来进行逐步跟踪分析。这样可以帮助定位具体哪一步出现了偏差从而采取相应措施加以修正。 ```python import logging logging.basicConfig(level=logging.DEBUG) ... data = load_data(some_param) if data is None: logging.error('Failed to get valid data') else: item_a, item_b = data # 此处假设应该有两个元素组成的序列 ``` 通过上述方法之一或多者的组合应用,应当能有效地找到引发 `TypeError: cannot unpack non-iterable NoneType object` 的原因,并对其进行修复。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值