内核融合:
1.将O(BLD+DN)字节的 (Δ,𝑨,𝑩,𝑪) 从慢速HBM读取到快速SRAM
2.在SRAM中离散化后,计算大小为(B,L,D,N)的,
3.采用并行关联扫描,在SRAM中产生大小 (B,L,D,N)的中间状态
4.与C做乘法和求和,计算得到大小为(B,L,D)的最终结果,并将最终结果写入HBM
上述操作使得IOs减少了 𝑂(𝑁) (状态维度) 的因子,这实际上将操作速度提高了20-40倍
重新计算:
在内核融合的过程中,为防止内存爆炸,我们不保存大小为(B,L,D,N)的的中间状态。然而,这些中间状态对于计算梯度的反向传播是必要的。故我们在反向传播中重新计算这些中间状态。
由于从HBM读取到SRAM的输入 Δ,𝑨,𝑩,𝑪 和输出梯度的大小为 𝑂(𝐵𝐿𝑁+𝐷),并且输入梯度的大小也为 𝑂(𝐵𝐿𝑁+𝐷𝑁),因此重新计算避免了从HBM读取 𝑂(𝐵𝐿𝑁𝐷) 个元素的成本。这意味着,与存储它们并从HBM读取它们相比,在反向传播中对SSM状态进行重新计算加速了计算。
文章使用重新计算来优化整个选择性SSM块 (扫描操作、输入投影、卷积、激活、扫描、输出投影) 的内存要求。
这种方法不保存占用大量内存但能够快速进行重新计算的中间内容(例如输出激活函数或短卷积)。