Mamba.py: 状态空间模型的并行扫描

本文介绍了Mamba.py中的并行扫描算法,特别是其对Blelloch算法的Python实现。文章讨论了扫描的概念,将其与外部变量的更新和内部状态的调整联系起来,并详细解释了如何通过并行计算加速累加求和,如矩阵求和,以及up-sweep和down-sweep的过程。此外,还提到了selective_scan在状态空间模型中的应用。
摘要由CSDN通过智能技术生成

Mamba.py:扫描和并行扫描

mamba.py/docs/pscan.ipynb at main · alxndrTL/mamba.py (github.com)

是Mamba.py作者对其编写的pscan即并行扫描的解释。pscan是Belloch算法的pytorch实现,即并行前缀和扫描,Belloch算法可以参照Understanding the implementation of the Blelloch Algorithm (Work-Efficient Parallel Prefix Scan) | by Shivam Mohan | Nerd For Tech | Medium以及Prefix Sums and Their Applications(cmu.edu)

但是根据作者所说,因没有应用原始实现的recomputation技术,因此会占用巨大的显存,所以仅作教学用途。
Mamba.py的训练速度见下图
在这里插入图片描述

什么是扫描

一个扫描定义为一个操作,把一个矩阵作为输入,产生一个矩阵作为输出。

我个人理解:

系统的输入或者说外部环境随时间变化,而我们系统也要随之不断更新,不断将这些变量“扫描”进去,因为我们只能处理离散信息,所以我们会有一个采样步长 Δ \Delta Δ,我们根据这些外部信息更新我们所需要的量,比如我们的输出量,控制量。

扫描是外部变量的扫描,也是内部变量的扫描,所以在我看来,扫描的同义词是更新,扫描就是根据输入更新状态空间模型的参数和输出。

什么是并行扫描

下面是一个因果卷积网络

输出依赖于之前时刻的输入,因此按顺序输入也一定按顺序输出,我们不可能在还可能有输入的情况下,就盖棺定论或者未卜先知得到下一个时刻的输出。

但在输入确定的情况下,例如我们只有6个点的输入,我们不需要再按顺序一个个计算,因为不确定的是输入,而不是输入与输出的关系,只要输入确定,我们没有必要对输入做不必要地等待,我们可以并行计算输出,我们可以并行的计算上面橙色的输出。
在这里插入图片描述

累加计算的例子

一个简单著名的扫描的例子是一个矩阵的累加求和。

X = torch.tensor([1, 2, 3, 4])

torch.cumsum(X, dim=0)

最简单的一个实现是for循环

Y = torch.zeros_like(X)

cumulative_sum = 0
for t in range(X.size(0)):
    cumulative_sum += X[t]
    Y[t] = cumulative_sum

我们使用了一个累加变量cumulative_sum

一个等价形式如下

Y = torch.zeros_like(X)

Y[0] = X[0]
for t in range(1, X.size(0)):
    Y[t] = Y[t-1] + X[t]

我们不再显式表达累加变量,但它实际在Y里面。

表达为了递归形式 Y [ t ] = Y [ t − 1 ] + X [ t ] Y[t] = Y[t-1]+X[t] Y[t]=Y[t1]+X[t]

在这里插入图片描述

有点像是RNN的形式,从某种角度来说,Y相当于隐状态,X相当于输入,当我们处理输入时,我们不断更新隐状态。

我们看到这种计算方法时顺序循环,又没有可能并行扫描操作。

矩阵求和简化

更进一步简化我们的目标为计算输入矩阵X的总和。

在这里插入图片描述

可以写成一个树状结构

L = 2 d L = 2^{d} L=2d则我们可以通过两两求和将总的计算次数由 L L L变为 d d d次顺序求和,如图,原有7次加法8个阶段,可以变成7次加法3个阶段,每个阶段内的加法可以同时进行。

矩阵求和python实现

X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # input array
L = X.size(0)

Xa = X

for k in range(int(math.log2(L))):
    T = 2 * (Xa.size(0) // 2)

    # 两两分成一对
    Xa = Xa.view(T//2, 2) 

    # 每一对的两个元素相加
    Xa[:, 1].add_(Xa[:, 0])

    # 更新Xa
    Xa = Xa[:, 1]

因为Xa是X的一个view,因此我们实际上在原地进行更新,更新完毕的X实际为[1, 3, 3, 10, 5, 11, 7, 36]

但我们将每次循环中不更新的部分去掉,变成一个树状结构。

累加求和的并行

在这里插入图片描述

我们看到一些节点值是输入矩阵的部分和,例如10是1到4的和。每个节点实际上是它们的子节点或子孙节点的和。

因此我们可以利用这些节点值来计算,如果一个节点是左节点,我们将他们与其兄弟右节点的左子节点相连,将相加得到的值更新到所连接的后代节点。由上而下更新

在这里插入图片描述

但是需要注意的是,我们从始至终没有开辟过新的空间,在原地进行值的更新。因此最下层,即真正的X为如下值,而红框圈出的部分实际并不是sum(3->5),因为他所加的值,即X[3]并不是上层即中间过程中的值即3到4的值,而是最后更新的值,即1到4的和,因此应更正为1->5的求和,至此,所有的累加和均已求出

在这里插入图片描述

因此整个过程分为两部分,一部分是up-sweep,即从下往上,以求整个矩阵总和的形式原地更新矩阵值,而原地操作,使得矩阵中的值实际上是不同层次的,再通过down-sweep,从上到下,利用已更新的值更新完剩下的值

总结下来:

  • 首先向上扫描,把顶部元素,即总和作为根
  • 将当前节点值赋值给右节点,而当前节点的左节点值,是当前节点值减去UP树对应右节点的值
  • 重复

在这里插入图片描述

例如,我们得到28,在最右边节点首先我们得到前八个数的和36,要得到前七个,那么就是减去第八个,即原来这个节点的值为8,36-8 = 28

Blelloch 算法

Blelloch前缀求和,和累加求和的区别是,前缀和不包括自身的值。

Up-sweep

两两成对求和

在这里插入图片描述

在这里插入图片描述

Down-sweep

在这里插入图片描述

在这里插入图片描述

最后在第七步我们得到前缀和

selective_scan

对于状态空间模型的扫描,主要更新的是隐状态,得到隐状态即可得到我们的输入。

通过pscan函数得到隐状态hs(B, L, ED, N)

 def selective_scan(self, x, delta, A, B, C, D):
        # x : (B, L, ED)
        # Δ : (B, L, ED)
        # A : (ED, N)
        # B : (B, L, N)
        # C : (B, L, N)
        # D : (ED)

        # y : (B, L, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
        
        hs = pscan(deltaA, BX) #隐状态x

        y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y

deltaA实际为 e Δ A e^{\Delta A} eΔA BX实际为 Δ B u \Delta B u ΔBu

deltaA重定义为Aa, BX重定义为Xa
x k = e Δ k A x k − 1 + Δ k B ⋅ u k \begin{aligned} x_k=e^{\Delta_{k}\boldsymbol A}x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k} \end{aligned} xk=eΔkAxk1+ΔkBuk

x k = A a   x k − 1 + X a \begin{aligned} x_k=Aa\ x_{k-1}+Xa \end{aligned} xk=Aa xk1+Xa
具体实现见pscan.py

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值