flash-attention123和ring-attention

1. flashattention-1

它的主要思想是希望把对于attention score matrix的计算拆成有规律的小块,以可控的顺序避免机器自己做无用的HBM <—> SRAM之间的读写。一个朴素的想法就是,既然q要对每个kv计算,那么我干脆只读一次kv,让q来回倒腾,比起倒腾kv,这样读写量更少。所以必须把KV放循环的外层。
实现flashattention-1时,最大的难点在于上溢,老式的attention会保留所有中间过程,所以总是在所有attention score出来之后,先减去max att-score再指数运算,就避免了上溢。而flashattention为了避免多次IO,必须舍弃中间过程,这就没法避免上溢了。一个聪明的做法是,由于上溢是没有除以max att-score引起的,而max att-score随着一个个att-score的出现,总是增加的,所以我完全可以在最终max att-score没出来之前,先让att-score减去一个当下最max的att-score,等下一个qk block计算完,如果max att-score被更新了,就对已经计算出来的前述(q*k)v再除以这其中的max的exp差值,同时对当下的(qk)*v应用当下更新完的max att-score权重,然后把当下和前述的加在一起。
这种方法多了许多max att-score更新导致的多余的除法操作,但是整体来看,由于减少了对q的反复读写,反而增加了速度。

2. flashattention-2

2对1的更新相对比较零散。
首先,作者提出,虽然你加速了att过程,但是GPU的吞吐量仍然很低。GPU吞吐量的意思可以这么理解:你要计算1+2+3+4,这四个数必须一个一个加,即使你作为高中生明明有很强的能力,你还是要一个一个加。这就是GPU吞吐量低的原因,GPU被设计为有利于矩阵运算,但是attention中还有大量的不能并行的非矩阵运算,这些运算导致GPU无法满负荷运转。

优化1. 调换Q和KV的位置、保存L而不是m和l

之前不是说flashattention1倒腾q以减少读写吗,到了2反过来了,外层循环q,内层循环kv。总之看两方论文的伪代码,两层for循环的位置,二者是调换的。不知道为什么。
但是另一个优化和它有关,即在2的伪代码13行。由于flashattention系列都不保留中间结果,所以在反向传播时还得算一次,而因为已经有了max att-score m和分母l,使得我们的计算能更简单一些,比起再次先算权重后除归一化因子,直接算子融合,让L替换m的位置,直接就能得到正确的归一化权重。
在这里插入图片描述

优化2.前向时,不需要提前除以归一化因子

我们实时更新max att-score,是为了防止计算分子时v的加权求和出现上溢。这实际上只需要管住权重就行,而不需要在当下就除以当下的归一化分母,否则如果下一个块时max att-score更新了,还得乘回来再除新的。我们在flashattention1中没有提到这个冗余的操作,估计很多人学的时候也没注意到。

优化3.略掉causal mask后恰好全被mask掉的qk块

这个很好理解,肯定是之前的mask操作是算完再mask,浪费时间了。

优化4. 并行

并行的前提是每个并行的计算之间不会互相干扰,例如batch之间、att head之间。GPU能支持的并行计算是固定的。
但是随着模型越来越大,单卡的batch越来越小,GPU支持的量好像比我们并行的量多了,那么也就是说有些GPU性能没用到,我们得想法子提高算法的并行度。
这时候突然想起来为什么flashattention-2要交换Q和KV的循环了。你在计算一个Q的attention score的时候,第二个Q和第一个Q完全无关,而如果你外循环是KV,你第二个KV就必须用到第一个KV留下来的值,例如m、l,这就把当前这组QKV的计算困在一个GPU并行单元里了。反之,如果外循环是Q,我完全可以拿前半组Q去一个单元里跑,后半组Q去另一个单元里跑。

前向已经被改成外圈Q内圈KV了,这也符合causal attention,第一个q会很快地退出计算,如果像flashattention-1一样,q1就要等到最后再计算完了,同时qn即使作为最后一个,即使计算量很小,也得等到前面的计算完再计算。到了反向传播,因为每个token只与一列attention有关,所以此时的循环改成外圈KV内圈Q。

对于Q分割还是KV分割,flashattention-1考虑的是尽量少地占用SRAM,所以让Q复制到每个warp里,KV按行分割成多份分散到warp里。这样做有个坏处,就是要同步更新所有warp的信息,例如max att-score。如果不考虑占用,肯定是分割Q、复制KV好。这与上一段一样,只不过是从thread block的程度细化到了warp的程度。建议先了解thread block线程块和warp线程的关系

  • 13
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值