Flash-Attention

这是一篇硬核的优化Transformer的工作。众所周知,Transformer模型的计算量和储存复杂度是 O ( N 2 ) O(N^2) O(N2) 。尽管先前有了大量的优化工作,比如LongFormer、Sparse Transformer、Reformer等等,一定程度上减轻了Transformer的资源消耗,但对Transformer的性能有所折损,且扩展性不强,不能泛化到其它领域、以及复杂结构的叠加。

这篇工作从底层对Transformer的计算和读写进行了优化,主要有三个贡献:

  1. 加速了模型计算:现在GPU的计算速度已经远远超过了内存读写速度(模型计算速度慢是因为IO慢,而不是 O ( N 2 ) O(N^2) O(N2)的原因导致。也就是说transformer的瓶颈在IO,而不是运算),当GPU完成计算后,内存却还在读取数据,造成GPU闲置而内存繁忙读(消费者早就消费完了,生产者还在缓慢生产)的现象,也就是内存墙问题。FlashAttention通过tiling和算子融合计算,将复杂操作放到SRAM中计算,并减少从HBM读取次数,加快了模型计算速度。而之前的工作虽然减少了Transformer的计算复杂度,却并没有减少模型计算时间。
  2. 节省了显存:FlashAttention通过引入全局统计量,避免实例化大注意力矩阵,减少了显存占用。
  3. 精确的注意力:FlashAttention从底层优化了Transformer的计算,但是任务指标上没有任何折损,与普通的Transformer结果是完全等价。

现代GPU内存分级

architecture

flash attention的思路就是尽量地在SRAM中进行分块计算、算子融合,减少对HBM(即常说的显存)的读写,从加快模型计算,减轻内存墙问题。

算法流程

algorithm

tiling分块计算

# ---------------------
# Tc: K和V的分块数
# Tr: Q的分块数量
# ---------------------
for 1 <= j <= Tc:
    for 1 <= i <= Tr:
        do....

loop

由于对 Q , K Q, K Q,K矩阵进行了分块,就无法进行全局归一化。我们的最终目的是得到 O O O ,作者这里根据公式推导,不断用当前最新的rowmax和rowsum去更新,直到遍历完最后一块,最终结果就和标准场景下的结果完全一致。

计算量和显存分析

  1. 计算量: O ( N 2 d ) O(N^2 d) O(N2d),跟标准attention计算一致

computation

  1. 显存: m ∈ R N , l ∈ R N m \in R^N, l \in R^N mRN,lRN

gpu memory

IO复杂度分析

  1. 标准attention

standard

  1. flash-attention

flash

可以看到,flash-attention通过算子融合、分块计算减少了IO,内存墙问题得以缓解。


参考

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值