文章目录
Mamba.py:扫描和并行扫描
mamba.py/docs/pscan.ipynb at main · alxndrTL/mamba.py (github.com)
是Mamba.py作者对其编写的pscan即并行扫描的解释。pscan是Belloch算法的pytorch实现,即并行前缀和扫描,Belloch算法可以参照Understanding the implementation of the Blelloch Algorithm (Work-Efficient Parallel Prefix Scan) | by Shivam Mohan | Nerd For Tech | Medium以及Prefix Sums and Their Applications(cmu.edu)
但是根据作者所说,因没有应用原始实现的recomputation技术,因此会占用巨大的显存,所以仅作教学用途。
Mamba.py的训练速度见下图
什么是扫描
一个扫描定义为一个操作,把一个矩阵作为输入,产生一个矩阵作为输出。
我个人理解:
系统的输入或者说外部环境随时间变化,而我们系统也要随之不断更新,不断将这些变量“扫描”进去,因为我们只能处理离散信息,所以我们会有一个采样步长 Δ \Delta Δ,我们根据这些外部信息更新我们所需要的量,比如我们的输出量,控制量。
扫描是外部变量的扫描,也是内部变量的扫描,所以在我看来,扫描的同义词是更新,扫描就是根据输入更新状态空间模型的参数和输出。
什么是并行扫描
下面是一个因果卷积网络
输出依赖于之前时刻的输入,因此按顺序输入也一定按顺序输出,我们不可能在还可能有输入的情况下,就盖棺定论或者未卜先知得到下一个时刻的输出。
但在输入确定的情况下,例如我们只有6个点的输入,我们不需要再按顺序一个个计算,因为不确定的是输入,而不是输入与输出的关系,只要输入确定,我们没有必要对输入做不必要地等待,我们可以并行计算输出,我们可以并行的计算上面橙色的输出。
累加计算的例子
一个简单著名的扫描的例子是一个矩阵的累加求和。
X = torch.tensor([1, 2, 3, 4])
torch.cumsum(X, dim=0)
最简单的一个实现是for循环
Y = torch.zeros_like(X)
cumulative_sum = 0
for t in range(X.size(0)):
cumulative_sum += X[t]
Y[t] = cumulative_sum
我们使用了一个累加变量cumulative_sum
一个等价形式如下
Y = torch.zeros_like(X)
Y[0] = X[0]
for t in range(1, X.size(0)):
Y[t] = Y[t-1] + X[t]
我们不再显式表达累加变量,但它实际在Y里面。
表达为了递归形式 Y [ t ] = Y [ t − 1 ] + X [ t ] Y[t] = Y[t-1]+X[t] Y[t]=Y[t−1]+X[t]
有点像是RNN的形式,从某种角度来说,Y相当于隐状态,X相当于输入,当我们处理输入时,我们不断更新隐状态。
我们看到这种计算方法时顺序循环,又没有可能并行扫描操作。
矩阵求和简化
更进一步简化我们的目标为计算输入矩阵X的总和。
可以写成一个树状结构
当 L = 2 d L = 2^{d} L=2d则我们可以通过两两求和将总的计算次数由 L L L变为 d d d次顺序求和,如图,原有7次加法8个阶段,可以变成7次加法3个阶段,每个阶段内的加法可以同时进行。
矩阵求和python实现
X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # input array
L = X.size(0)
Xa = X
for k in range(int(math.log2(L))):
T = 2 * (Xa.size(0) // 2)
# 两两分成一对
Xa = Xa.view(T//2, 2)
# 每一对的两个元素相加
Xa[:, 1].add_(Xa[:, 0])
# 更新Xa
Xa = Xa[:, 1]
因为Xa是X的一个view,因此我们实际上在原地进行更新,更新完毕的X实际为[1, 3, 3, 10, 5, 11, 7, 36]
但我们将每次循环中不更新的部分去掉,变成一个树状结构。
累加求和的并行
我们看到一些节点值是输入矩阵的部分和,例如10是1到4的和。每个节点实际上是它们的子节点或子孙节点的和。
因此我们可以利用这些节点值来计算,如果一个节点是左节点,我们将他们与其兄弟右节点的左子节点相连,将相加得到的值更新到所连接的后代节点。由上而下更新
但是需要注意的是,我们从始至终没有开辟过新的空间,在原地进行值的更新。因此最下层,即真正的X为如下值,而红框圈出的部分实际并不是sum(3->5),因为他所加的值,即X[3]并不是上层即中间过程中的值即3到4的值,而是最后更新的值,即1到4的和,因此应更正为1->5的求和,至此,所有的累加和均已求出
因此整个过程分为两部分,一部分是up-sweep,即从下往上,以求整个矩阵总和的形式原地更新矩阵值,而原地操作,使得矩阵中的值实际上是不同层次的,再通过down-sweep,从上到下,利用已更新的值更新完剩下的值
总结下来:
- 首先向上扫描,把顶部元素,即总和作为根
- 将当前节点值赋值给右节点,而当前节点的左节点值,是当前节点值减去UP树对应右节点的值
- 重复
例如,我们得到28,在最右边节点首先我们得到前八个数的和36,要得到前七个,那么就是减去第八个,即原来这个节点的值为8,36-8 = 28
Blelloch 算法
Blelloch前缀求和,和累加求和的区别是,前缀和不包括自身的值。
Up-sweep
两两成对求和
Down-sweep
最后在第七步我们得到前缀和
selective_scan
对于状态空间模型的扫描,主要更新的是隐状态,得到隐状态即可得到我们的输入。
通过pscan函数得到隐状态hs(B, L, ED, N)
def selective_scan(self, x, delta, A, B, C, D):
# x : (B, L, ED)
# Δ : (B, L, ED)
# A : (ED, N)
# B : (B, L, N)
# C : (B, L, N)
# D : (ED)
# y : (B, L, ED)
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
hs = pscan(deltaA, BX) #隐状态x
y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
y = y + D * x
return y
deltaA实际为 e Δ A e^{\Delta A} eΔA BX实际为 Δ B u \Delta B u ΔBu
deltaA重定义为Aa, BX重定义为Xa
x
k
=
e
Δ
k
A
x
k
−
1
+
Δ
k
B
⋅
u
k
\begin{aligned} x_k=e^{\Delta_{k}\boldsymbol A}x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k} \end{aligned}
xk=eΔkAxk−1+ΔkB⋅uk
即
x
k
=
A
a
x
k
−
1
+
X
a
\begin{aligned} x_k=Aa\ x_{k-1}+Xa \end{aligned}
xk=Aa xk−1+Xa
具体实现见pscan.py