#1024程序员节|征文#
HiPPO的学习暂告一段落,按照“HiPPO->S4->Mamba 演化历程”,接着学习S4。
《Efficiently Modeling Long Sequences with Structured State Spaces》
文章链接:https://ar5iv.labs.arxiv.org/html/2111.00396
https://arxiv.org/abs/2111.00396
摘要
文章提出了一种名为S4(Structured State Space sequence model)的序列模型,旨在有效处理长距离依赖(LRDs)。尽管现有的模型如RNN、CNN和Transformer等有专门变体来捕获长距离依赖,但它们在处理超过10000步的非常长序列时仍然存在困难。最近一种基于状态空间模型(SSM)的方法展示了通过适当选择状态矩阵A,可以在数学和实证上处理长距离依赖。然而,这种方法在计算和内存需求上成本过高,使其不适用于一般序列建模解决方案。S4模型通过新的参数化方法对SSM进行了改进,使得它在保持理论优势的同时,计算效率大大提高。该技术涉及对A进行低秩修正,使其可以稳定地对角化,并简化SSM为一个著名的Cauchy核的计算。S4在多个基准测试中取得了强大的实证结果,包括在无需数据增强或辅助损失的顺序CIFAR-10上的准确率达到91%,与更大的2-D ResNet相当,在图像和语言建模任务上与Transformer的差距显著缩小,同时生成速度快60倍,在Long Range Arena基准测试的每个任务上都达到了SoTA,包括解决了所有先前工作都失败的具有挑战性的Path-X任务,长度为16k,同时与所有竞争对手一样高效。
解决的主要问题
- 序列建模中长距离依赖(LRDs)的有效处理。
- 现有模型在处理非常长序列时的局限性。
- 先前基于SSM的方法在计算和内存需求上的高成本问题。
方法
- 提出了S4模型,它基于SSM的新参数化,通过将状态矩阵A分解为低秩和正常项的和,使得A可以被稳定地对角化。
- 利用Woodbury identity和Cauchy核的计算,将SSM的计算复杂度从O(N^2L)降低到O(N+L),其中N是状态维度,L是序列长度。
- 在多个任务和数据集上验证了S4模型的性能,包括顺序CIFAR-10、WikiText-103语言建模、图像分类和时间序列预测等。
介绍
序列建模的一个核心问题是有效地处理包含长距离依赖(LRDs)的数据。现实世界中的时间序列数据通常需要在数万步时间步长上进行推理,而很少有序列模型能够处理甚至成千上万步的数据。例如,长距离范围竞技场(LRA)基准测试的结果[40]凸显了当今序列模型在LRD任务上的表现不佳,包括一个(Path-X)没有任何模型的表现优于随机猜测的情况。
由于LRDs可能是序列模型面临的最大挑战,所有标准模型族,如连续时间模型(CTMs)、RNNs、CNNs和Transformers,都包括许多旨在解决它们的专门变体。现代的例子包括正交和Lipschitz RNNs[1, 13]来对抗梯度消失问题,扩张卷积来增加上下文大小[3, 28],以及日益庞大的高效Transformer家族,它们减少了对序列长度的二次依赖[8, 22]。尽管这些解决方案针对LRDs设计,但在LRA[40]或原始音频分类[18]等具有挑战性的基准测试上,它们的表现仍然不佳。
最近基于状态空间模型(SSM)的一种替代方法被提出来解决LRDs问题(图1)。SSM是在控制理论、计算神经科学等领域中使用的基础科学模型,但由于具体理论原因,它们并未适用于深度学习。特别是,Gu等人[18]表明,深度SSMs即使在简单任务上也表现不佳,但如果配备了最近推导出的特殊状态矩阵A来解决连续时间记忆问题[16, 45],则可以表现出色。他们的线性状态空间层(LSSL)在概念上统一了CTM、RNN和CNN模型的优势,并提供了一个概念验证,即深度SSMs在原则上可以解决LRDs问题。
不幸的是,由于状态表示引起的计算和内存需求过高,LSSL在实践中不可行。对于状态维度N和序列长度L,计算潜在状态需要O(N^2L)的操作和O(NL)的空间——与两者的Ω(L + N)的较低界限相比。因此,对于合理尺寸的模型(例如,Gu等人[18]中的N = 256),LSSL使用的内存比同样大小的RNN或CNN多几个数量级。尽管为LSSL提出了理论上高效的算法,但我们展示这些算法在数值上是不稳定的。特别是,特殊的A矩阵在线性代数意义上是非正常的,这阻止了传统算法技术的应用。因此,尽管LSSL展示了SSMs的强大性能,但它们作为一般序列建模解决方案在计算上目前是不切实际的。
在这项工作中,我们提出了基于SSM的结构化状态空间(S4)序列模型,它解决了之前工作中的关键计算瓶颈。技术上,S4重新参数化了Gu等人[16]和Voelker等人[45]中出现的结构化状态矩阵A,通过将它们分解为低秩和正常项的和。此外,我们不是在系数空间中展开标准SSM,而是计算其在频率空间中的截断生成函数,这可以简化为类似多极子的评估。结合这两个想法,我们展示了低秩项可以通过Woodbury恒等式进行校正,而正常项可以稳定地对角化,最终简化为一个得到了很好研究并且理论上稳定的Cauchy核[29, 30]。这导致计算和内存使用都是˜O(N + L),这对于序列模型来说是基本的。与LSSL相比,S4在速度上快了30倍,内存使用减少了400倍,同时在实证上超过了LSSL的性能。
迈向通用序列模型
除了LRD,机器学习的一个广泛目标是开发一个可以用于各种问题的单一模型。如今的模型通常专门用于解决特定领域(如图像、音频、文本、时间序列)的问题,并实现了一系列能力(如高效训练、快速生成、处理不规则采样数据)。这种专业化通常通过特定领域的预处理、归纳偏差和架构来表达。序列模型提供了一个通用的框架,可以通过减少专业化来解决许多这些问题,例如用于2D信息较少的图像分类的Vision Transformers[12]。然而,大多数模型(如Transformers)通常仍需要对每项任务进行大量专业化,以实现高性能。
深度SSM尤其具有概念优势,表明它们可能是一种有前景的通用序列建模解决方案。这些优势包括处理LRD的原则性方法,以及在连续时间、卷积和循环模型表示之间移动的能力,每种表示都有不同的能力(图1)。我们的技术贡献使SSM能够以最小的修改成功应用于各种基准:
大规模生成建模。在CIFAR-10密度估计中,S4与最佳的自回归模型(2.85
bits/dim)具有竞争力。在WikiText-103语言建模上,S4基本上缩小了与Transformers的差距(在0.8困惑度以内),将SoTA设置为无注意力模型。
快速自回归生成。与RNN一样,S4可以利用其潜在状态在CIFAR-10和WikiText-103上执行比标准自回归模型快60倍的像素/令牌生成。
采样分辨率变化。与专用CTM一样,S4可以适应时间序列采样频率的变化,而无需重新训练,例如在语音分类中以0.5倍的频率。
学习时带有较弱的归纳偏见。在没有架构变化的情况下,S4在语音分类方面超越了语音CNN,在时间序列预测问题上超越了专门的Informer模型,并在顺序CIFAR上与2-D ResNet匹配,准确率超过90%。
文章翻译
下面是翻译工具对这篇文章的大概翻译,作为参考,放到这里。很多地方翻译的不到位,凑合看吧。