flash attention1和2


欢迎关注微信公众号InfiniReach,这里有更多AI大模型的前沿算法与工程优化方法分享
请添加图片描述

attention

在这里插入图片描述
(1)(2)设计矩阵乘法,计算复杂度都是 O ( d N 2 ) O(dN^2) O(dN2)(仅乘法),复杂度随序列长度的增长呈二次方增长

GPU的显存即HBM带宽为1.5-2.0TB/s;片上SRAM的带宽约为19TB/s。在attention计算中涉及大量访存事务:

  1. 读QK,写S
  2. 读S,写P
  3. 读PV,写O

访存复杂度 O ( 2 d N + N 2 + N 2 + N 2 + N 2 + d N + d N ) = O ( d N + N 2 ) O(2dN + N^2 + N^2+N^2+N^2+dN+dN) =O(dN+N^2) O(2dN+N2+N2+N2+N2+dN+dN)=O(dN+N2)

flash attention 1

FlashAttention将优化重点放在了降低存储访问开销,为了减少对HBM的读写,FlashAttention将参与计算的矩阵进行分块送进SRAM,减少了HBM访存,来提高整体读写速度。

softmax时涉及计算每行元素最大值m(x),计算 ∑ i e x i − m ( x ) \sum_i e^{x_i - m(x)} iexim(x),将QK划分成块后,只能计算局部,设置长为N的全局max(x)和全局 ∑ i e x i − m a x ( x ) \sum_i e^{x_i - max(x)} ieximax(x),每次计算完局部后,更新这俩全局向量,更新的方法就是乘 e m a x ( x ) o l d − m a x ( x ) n e w e^{max(x)_{old}-max(x)_{new}} emax(x)oldmax(x)new,消掉原来的局部max。

在这里插入图片描述
在这里插入图片描述
外层循环KV,内层循环QO,计算内层访存复杂度,内存循环一次访存Nd,外层循环 N B c = 4 d N M \frac{N}{B_c}=\frac{4dN}{M} BcN=M4dN,共计 O ( N 2 d B − 1 ) = O ( N 2 d 2 M − 1 ) O(N^2dB^{-1})=O(N^2d^2M^{-1}) O(N2dB1)=O(N2d2M1)。 M(100KB)通常远远大于 d(数K),所以FlashAttention的MAC远小于标准的Transformer O ( d N + N 2 ) O(dN+N^2) O(dN+N2)

FlashAttention 的 向后传递需要SP矩阵来计算 Q,K,V 的梯度。然而由于空间复杂度是 O(N^2) ,没有显式存储。使用输出 O 和 softmax 归一化统计 (m,ℓ),利用SRAM中的 Q,K,V 重新计算 S 和 P 矩阵。这个过程使用更多的flop,由于减少HBM访问,重新计算也加快了反向传播的速度。

FlashAttention 的速度优化原理是怎样的? - Civ的回答 - 知乎
https://www.zhihu.com/question/611236756/answer/3132304304

flash attention 2

  1. 减少了non-matmul FLOPs的数量(消除了原先频繁rescale)。虽然non-matmul FLOPs仅占总FLOPs的一小部分,但它们的执行时间较长,这是因为GPU有专用的矩阵乘法计算单元,其吞吐量高达非矩阵乘法吞吐量的16倍。因此,减少non-matmul FLOPs并尽可能多地执行matmul FLOPs非常重要。

  2. 提出了在序列长度维度上并行化。该方法在输入序列很长(此时batch size通常很小)的情况下增加了GPU利用率。即使对于单个head,也在不同的thread block之间进行并行计算。

  3. 在一个attention计算块内,将工作分配在一个thread block的不同warp上,以减少通信和共享内存读/写。

  4. 原版设置长为N的全局max(x)和全局 ∑ i e x i − m a x ( x ) \sum_i e^{x_i - max(x)} ieximax(x),需要每次更新这俩。新版里先不计算分母,即softmax没分母,等最后再除以分母:

在这里插入图片描述

GPU的Tensor Cores FP16/BF16矩阵乘法的最大理论吞吐量为312 TFLOPs/s,但FP32非矩阵乘法仅有19.5 TFLOPs/s,即每个no-matmul FLOP比mat-mul FLOP昂贵16倍。

  1. FlashAttention-2还在序列长度这一维度上进行并行化,显著提升了计算速度。此外,当batch size和head数量较小时,在序列长度上增加并行性有助于提高GPU占用率。

FlashAttention在batch和heads两个维度上进行了并行化:使用一个thread block来处理一个attention head,总共需要thread block的数量等于batch size × number of heads。但是在处理长序列输入时,由于内存限制,通常会减小batch size和head数量,这样并行化成都就降低了。

  1. 原版外层循环是KV,内存循环是QO,O需要刷新外层循环的次数;所以把QO设为外层循环,KV设为内层循环。将FlashAttention的Q共享,KV切4个warp,改进为Q切4个warp,共享kv,warp之间不再需要通信,减少了share memory读写,带来了性能提升

https://zhuanlan.zhihu.com/p/645376942

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值