Mamba框架及其NLP领域优势 VS Transformer

概要

在这里插入图片描述

Mamba模型是一款内核为线性循环神经网络的新型序列模型,以优秀的计算量和语言建模能力力压Transformer。所以为什么它会被ICLR拒稿,这点作者也说不清楚,这也不是本文的重点,本文的重点是理清楚Mamba的优势和变化到底在哪里。

原文链接: https://arxiv.org/pdf/2312.00752
代码链接:https://github.com/state-spaces/mamba

Mamba框架

设计Mamba的动机

基础模型,现在为深度学习中大多数令人兴奋的应用提供动力,几乎都是基于Transformer架构及其核心注意力模块。许多次二次时间架构如线性注意力、门控卷积和循环模型以及结构化状态空间模型( SSMs )已经被开发,以解决Transformer在长序列上的计算效率低下(几何注意力层无法扩展)问题,但它们在语言等重要模态上的表现(长程依赖性不够或者预测的信息量不足)不如注意力。由此该团队发现这些模型的关键弱点是它们不能执行基于内容的推理,并提出了一些改进。
概括的来说,Mamba这个框架就是通过线性结构解决了Transformer计算效率问题,同时通过自己的改进又解决了线性结构本身具有的长程依赖能力不足等等核心问题,从而在自然语言处理(Natural Language Processing,NLP)建模任务上获得巨大提升。

Mamba结构流程

Alt
从总体架构来看,Mamba要比Transformer简单很多,组件也少一点,就是把选择性状态空间模型(Selective State Space Model, SSM)这种线性结构和深度学习中的非线性MLP结合到了一块,以此来适应NLP任务的特点。因此,其核心算子自然就是这个SSM模块。

SSM模块

SSM全称State Space Model,被译为状态空间模型,看似很复杂,实际上把他理解为一种动态递归线性循环神经网络就好。SSM有多个版本,结构流程图中展示的为S4版本,而实际上Mamba对此进行了改进,当前版本为S6,如下图所示。
在这里插入图片描述
在解释这样一个线性结构之前,就不得不提一下循环神经网络前辈RNN了。RNN的流程和上图及其相似,也是将一个输入乘以一个权重去加上上一个时间步长的信息,理论上来说这样是可以得到所有数据信息的。但是RNN这种简单的操作会导致两个较为致命的问题:①连续的项链一般的计算过程导致无法进行并行计算,从而训练过程及其缓慢;②简单的乘性计算方式导致梯度及其不稳定,不是越乘越大导致梯度爆炸,就是越乘越小导致学不到啥东西。因此抛开理论谈实际的话,RNN只能记住前面几个步长的信息。
那么就有了一个新的概念需要解释了:递归线性循环神经网络或者说递归循环神经网络(Resurrecting Recurrent Neural Networks1)。人们都知道RNN其实是个好东西,但是就是不好用,怎么办?最简单的解决办法就是:通过递归线性化和对角化、使用更好的参数化和初始化,以及确保前向传递的归一化,最终实现对循环网络权重的控制。说的简单一点,就是通过人为控制权重的范围,使得循环神经网络不会产生梯度问题,这样子它就可以真正稳如老狗一般的实现几乎无穷的长程依赖能力,根据一个长距离竞技场(Long Range Arena2)的文章提出的评价指标来看,这种递归线性循环神经网络在单纯的长程依赖能力方面是可以碾压Transformer的。但是又有一个新的问题,Mamba是做NLP任务的啊,而实践证明递归线性循环神经网络虽然可以记住很久之前的信息,但是NLP建模并不好。所以Selection SSM就诞生了,或者说一种动态递归线性循环神经网络就诞生了。
S6版本的SSM公式如下:
在这里插入图片描述
(1a)和(1b)表示的是连续空间状态方程,离散化后会得到离散的状态空间模型,即(2a)和(2b),具体的推导过程这里就不推了,涉及到稍稍复杂的数学公式。对应到流程图就是当前的输入 x t x_{t} xt会乘以权重 B ‾ \overline{B} B再加上上一个时间步长 x h − 1 x_{h-1} xh1乘以权重 A ‾ \overline{A} A从而得到当前的时间步长 x h x_{h} xh,当前的时间步长 x h x_{h} xh乘以一个权重 C C C就可以得到当前的一个输出值 y h y_{h} yh了。
在这个过程中权重 A ‾ \overline{A} A B ‾ \overline{B} B C C C都是经过递归和约束过范围的,因此具有极佳的长程依赖能力,意味着再长的序列它也能记住,但是前面说了NLP任务不一定好,且原文提出构建序列模型的一个基本原则是选择性:或者说上下文感知的关注或过滤输入到序列状态的能力。因此Mamba加了一个Selection Mechanism。具体的变化见下表:
在这里插入图片描述
在这里插入图片描述

由上表可以发现,几个主要的权重生成方式被完全改变。由原先的Parameter变为了 s B ( x ) s_{B}(x) sB(x) s C ( x ) s_{C}(x) sC(x)等等,而根据原文图示可以发现这些权重本质为 L i n e a r N ( x ) Linear_{N}(x) LinearN(x),也就是和输入相关的线性变换结果被设置为了可学习参数作为权重。这种看似简单的变化直接将SSM从时不变系统转型成为时变系统,因为当前的权重变化不再是一个不随时间变换的参数量,而是和随时间变换的输入 x x x相关的学习参数。至此,SSM不仅仅具有了超强的长程依赖能力,同时具备了可以针对NLP任务特点的时变筛选能力,这意味着我可以选择自己感兴趣且重要的信息去进行记忆,而抛弃那些不重要的,对于语序推理不利的特征。

Mamba超强计算性能

OK!说完Mamba的语言建模能力,那么它所谓的推理速度优势又是怎么体现的。
主要的原因有两点:
①线性模型具有的天生优势就是计算量比较低,计算量和输入尺寸呈现 l o g n log_{n} logn的关系。然后,就不得不提到Transformer的计算效率的问题,在训练过程中Transformer的计算效率还是很能打的,原因在于可以通过GPU并行加速实现快速训练,但是在测试过程中由于其计算量和输入长度呈现平方倍 n 2 n^{2} n2的关系,因此时间复杂度随参数量的增加呈现指数级增长,这就直接导致在进行大规模或者超算数据处理时需要投入更多资源。最直接的体现就是时间复杂度的提升,间接的体现就是数据吞吐量以及训练结果上限的问题。
②Mamba团队进行的底层GPU加速设计。具体的说是将他们在进行SSM模块运行设计时,没有在GPU HBM (高带宽存储器)中准备尺寸为( B,L,D,N)的扫描输入( A ‾ \overline{A} A B ‾ \overline{B} B),而是直接从慢速HBM加载SSM参数A、B、C到快速SRAM中,在SRAM中执行离散化和递归,然后将尺寸为( B,L,D)的最终输出写回HBM。这种操作直接增加了SSM的数据计算速度,也正是因为计算速度大幅提升,其数据搬运的时间复杂度反而成为了大头。
还有一点值得一提,虽然在标准的循环神经网络里,输入和输出以及处理过程尺寸要一样大,但是Mamba就格外贪心,它将输入向量扩大了n倍来进行数据计算,从而存储从过去获得的更加大量信息,之后这些向量会在输出之前倍缩回原本尺度再传递给下一层。虽然说这种操作会将计算时间复杂度也扩大n倍,但是相较于数据搬运成本,这些增加的计算时间还是杯水车薪几乎可以忽略,由此可见其计算效率的提升有多巨大。

Mamba实验效果

在这里插入图片描述
NLP建模任务,效果√。

在这里插入图片描述
在这里插入图片描述
医学DNA序列建模任务,效果√。

在这里插入图片描述
最离谱的扫描速度和吞吐量。左边高pytorch十几倍,右边吞吐量高Transformer有将近5倍,效果√。

小结

Mamba的很多具体细节需根据原文以及公式推导进行详细的分析,有需要的话还是推荐查看原论文。如果大家对此感兴趣,下一篇准备出一些Mamba在cv领域的一些应用。
以下是本次介绍的一些其他参考文献:


  1. Resurrecting Recurrent Neural Networks for Long Sequences ↩︎

  2. LONG RANGE ARENA: A BENCHMARK FOR EFFICIENT
    TRANSFORMERS
    ↩︎

  • 27
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值