Mamba中的硬件感知算法——并行关联扫描

内核融合:

1.将O(BLD+DN)字节的 (Δ,𝑨,𝑩,𝑪) 从慢速HBM读取到快速SRAM

2.在SRAM中离散化后,计算大小为(B,L,D,N)的\bar{A},\bar{B}

3.采用并行关联扫描,在SRAM中产生大小 (B,L,D,N)的中间状态

4.与C做乘法和求和,计算得到大小为(B,L,D)的最终结果,并将最终结果写入HBM

上述操作使得IOs减少了 𝑂(𝑁) (状态维度) 的因子,这实际上将操作速度提高了20-40倍

重新计算:

在内核融合的过程中,为防止内存爆炸,我们不保存大小为(B,L,D,N)的的中间状态。然而,这些中间状态对于计算梯度的反向传播是必要的。故我们在反向传播中重新计算这些中间状态。

由于从HBM读取到SRAM的输入 Δ,𝑨,𝑩,𝑪 和输出梯度的大小为 𝑂(𝐵𝐿𝑁+𝐷),并且输入梯度的大小也为 𝑂(𝐵𝐿𝑁+𝐷𝑁),因此重新计算避免了从HBM读取 𝑂(𝐵𝐿𝑁𝐷) 个元素的成本。这意味着,与存储它们并从HBM读取它们相比,在反向传播中对SSM状态进行重新计算加速了计算。

文章使用重新计算来优化整个选择性SSM块 (扫描操作、输入投影、卷积、激活、扫描、输出投影) 的内存要求。

这种方法不保存占用大量内存但能够快速进行重新计算的中间内容(例如输出激活函数或短卷积)。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值