序列建模架构Mamba
记录一下刚学到的mambda的基本内容
大致理解
mambda是改进的RNN,其可以变换为RNN的递归生成形式也可以变换为CNN的并行生成形式。
在推理的时侯为O(n)复杂度。训练时,由于前缀和算法的牛逼,复杂度降到O(nlogn)?
1.
首先是一个简化的RNN形式的公式:
2. 但也可以变换为CNN
y3可以由x0,x1,x2,x3一起并行计算得出
3. 于是有了:
4.同时为了增加可学习性,让ABC等参数都可以随着输入而改变。
即用一个网络根据输入预测参数值
5. 无法用CNN方式进行训练,只能用RNN形式
而为了加速计算。则需要求下式,也就是像一个前缀和一样的东西。
那么可以用超强的前缀和并行算法来计算:
reference
[1] Mamba原理最通俗介绍火了,一文看懂“Transformer挑战者”两大主要思想!网友:年度最佳解读 - 量子位的文章 - 知乎 链接: link