论文笔记《Self-Attention ConvLSTM for Spatiotemporal Prediction》

1. Abstract

  • 为提取空间特征的全局和局部依赖性,本文向ConvLSTM引入了一个新的自注意力机制(self-attention mechanism)
  • 子注意力记忆模块(self-attention memory , SAM) 能在时空域记住那些具有长期依赖性的特征

2. Introduction

本文的创新点/贡献在于:

  • 提出一个新的基于ConvLSTM的变体模型用于时空预测,命名为SA-ConvLSTM,特点是能很好捕获长程空间依赖性
  • 设计了一个基于记忆的自相关模块(memory-based self-attention module, SAM),该模块用于在预测中记住全局的时空依赖性。
  • 为验证模型(1)使用MovingMNIST 和 KTH 数据集进行多框架预测;(2)使用TaxiBJ数。据集预测交通流量。本文模型优势是参数更少、效率更高。

写作思路:

  • 首段:(1)交待时空预测研究的重要性、现有研究(很简要),说明值得研究;(2)时空预测具有复杂动态性,时空领域都表现出依赖性。
  • 第二段:(1)ConvLSTM 效果不错;(2)存在问题1-长程依赖可以通过堆叠的卷积层捕获,但有效感受野要比理论上的感受野小很多;(3)存在问题2-离特殊位置较远的特征,要体现位置的影响 实现前馈和反向传播,就要经过很多层,这样一来训练时的优化就很困难;(4)现有的解决办法只能提供稀疏的依赖关系,估计的是局部感受野;(5)因此现有问题就是如何让ConvLSTM捕获到长程依赖性。
  • 第三段:(1)认为自注意力模块相对于卷积操作,更擅于获得全局空间上下文信息(注意:这里只是说普通的self-attention module),因此本文使用额外的记忆单元 M \mathcal{M} M ;(2) M \mathcal{M} M 也能像LSTM 通过门控机制捕获长程的时间依赖性。

3. Method

注:原文比较简略,下文按照自己的理解重新组织了顺序

3.1 模型整体结构

这篇文章创新点就是加了一个基于记忆的自相关模块(memory-based self-attention module, SAM),这个模块是接在ConvLSTM模型的最后的,如图浅绿色部分(如果没有它及其输出,这个图就是ConvLSTM模型图,或者说是LSTM模型图):
在这里插入图片描述

3.2 SAM模块

在这里插入图片描述
这个模块看上去好复杂,基于文章描述 它可以分为三个小部分,我在图上用不同色块标注出来(强迫症不允许色块对不齐,是不是很整齐hh):

  • 黄色区域:特征聚合,文章中的Feature Aggregation 部分
  • 蓝色区域:记忆更新,文章中的Memory Updating 部分
  • 绿色区域:输出,文章中的Output 部分

3.2.1 Feature Aggregation 特征聚合

在这里插入图片描述
整个黄色区域可分为两部分:

  • 上半部分(黄色):输入是当前时刻特征 H t \mathcal{H_t} Ht,经历一个普通的self-attention 模块,得到 Z h Z_h Zh
  • 下半部分(灰色):输入是上一时刻记忆 M t − 1 \mathcal{M}_{t-1} Mt1,也是经历一个self-attention 模块。不同的是,此处用的query Q Q Q 是当前时刻计算得到的,key K K K 是上一时刻 M t − 1 \mathcal{M}_{t-1} Mt1 计算得到的,通过 e = Q h T K h ∈ R N × N \mathbf{e}=\mathbf{Q}_{h}^{T} \mathbf{K}_{h} \in \mathbb{R}^{N \times N} e=QhTKhRN×N 计算相似性得分,然后再经过 softmax 将得分映射至 (0,1) 区间。最后再将得分与上一时刻记忆 M t − 1 \mathcal{M}_{t-1} Mt1 的值相乘,得到 Z m Z_m Zm
  • 通过通道相连将这两个输出拼一起,再乘权重,得到 Z Z Z
  • Z Z Z 再与当前时刻特征 H t \mathcal{H_t} Ht 拼接到一起,作为下一步骤的输入。

3.2.2 Memory Updating 记忆更新

在这里插入图片描述
感觉这部分记忆更新操作和GRU操作很像,具体操作如下,分两步走:

  • 通过tanh 对输入数据处理,将其映射到[-1,1]。 g t ′ = tanh ⁡ ( W m ; z g ∗ Z + W m ; h g ∗ H t + b m ; g ) g_{t}^{\prime}=\tanh \left(W_{m ; z g} * \mathbf{Z}+W_{m ; h g} * \mathcal{H}_{t}+b_{m ; g}\right) gt=tanh(Wm;zgZ+Wm;hgHt+bm;g)
  • 通过sigmoid 处理数据,将其映射到[0,1]上,形成gate。 i t ′ = σ ( W m ; z i ∗ Z + W m ; h i ∗ H t + b m ; i ) i_{t}^{\prime}=\sigma\left(W_{m ; z i} * \mathbf{Z}+W_{m ; h i} * \mathcal{H}_{t}+b_{m ; i}\right) it=σ(Wm;ziZ+Wm;hiHt+bm;i)
  • 最后更新记忆信息, M t = ( 1 − i t ′ ) ∘ M t − 1 + i t ′ ∘ g t ′ \mathcal{M}_{t}=\left(1-i_{t}^{\prime}\right) \circ \mathcal{M}_{t-1}+i_{t}^{\prime} \circ g_{t}^{\prime} Mt=(1it)Mt1+itgt

3.2.3 Output 输出

在这里插入图片描述
最后就是输出:

  • 先门控处理 o t ′ = σ ( W m ; z o ∗ Z + W m ; h o ∗ H t + b m ; o ) o_{t}^{\prime}=\sigma\left(W_{m ; z o} * \mathbf{Z}+W_{m ; h o} * \mathcal{H}_{t}+b_{m ; o}\right) ot=σ(Wm;zoZ+Wm;hoHt+bm;o)
  • 输出 H ^ t = o t ′ ∘ M t \hat{\mathcal{H}}_{t}=o_{t}^{\prime} \circ \mathcal{M}_{t} H^t=otMt

4. Experiments

4.1 Implementation

  • 设计为一个4层的网络,每一层有64隐藏层
  • ADAM optimizer
  • 初始学习率为0.001
  • 训练中的mini-batch=8,80000次迭代后收敛
  • MovingMNIST 和 TaxiBJ 数据集使用L2 loss
  • KTH数据集使用L1+L2 loss
  • 指标:SSIM(structural similarity Index Measure);MSE;MAE

4.2 Ablation Study

  • 标准的4层ConvLSTM
  • 只有self-attention的ConvLSTM模型
  • 只有additional memory cell M \mathcal{M} M 的ConvLSTM模型
  • 没有 Z m Z_m Zm 的SA-ConvLSTM 模型
  • 完整的SA-ConvLSTM 模型

在结果分析部分:
分析1:Ablation Study
在这里插入图片描述
分析2:不同模型之间的比较
在这里插入图片描述
在这里插入图片描述

分析3:MovingMNIST数据集的定性比较(用过去10帧预测未来10帧)
在这里插入图片描述
分析4: TaxiBJ数据集的定性比较(用过去4帧预测未来4帧,即两小时)
在这里插入图片描述
颜色越亮,绝对误差越高。虽然感觉不太方便对比,但是看上去很炫酷,有人知道这是用什么画出来的吗 🤔

最后还有两个图像识别任务的可视化结果,然后就是简短的总结,不翻译了~

  • 10
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 30
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 30
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值