LLM 优化技术(4)——flash_attention 原理

论文地址:FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness

在这里插入图片描述

如上图所示,Flash-attention 采用了矩阵分块和算子融合(safe softmax reducetion)的方法,尽可能的减少内存的 IO 时间,最大化利用 GPU 硬件中的共享内存,减少去 HBM 中进行内存的搬运等操作。在长序列的 LLM 推理任务中,能够有效的减少推理的时间,尤其是 prefill 阶段。

1 计算复杂度分析

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax \left( \frac{QK^T}{\sqrt{d_k}} \right)V Attention(Q,K,V)=softmax(dk QKT)V

flashattention 计算注意力的 IO 复杂度为 O ( N 2 d 2 M − 1 ) O(N^2 d^2 M^{-1}) O(N2d2M1),其中 d 是 head_dimension ,M是 SRAM 的大小,标准注意力的复杂度 O ( N d + N 2 ) O(Nd + N^2) O(Nd+N2), 考虑到 M 的值相比 d 通常大很多,因此前者的复杂度远远小于后者,且 SRAM 越大越明显。

1.1 Naive Attention 内存加载流程
  • 从 HBM 加载 Q, K , 计算 S = Q K T S = QK^T S=QKT , 将 S 写回 HBM, N d + N 2 Nd + N^2 Nd+N2
  • 从 HBM 读取 S, 计算 P = s o f t m a x ( S ) P = softmax(S) P=softmax(S), 将 P 写回 HBM, 这一步 2 N 2 2N^2 2N2
  • 从HBM读取 P、V, 计算 O = P V O = PV O=PV, 将 O 写回到 HBM, 这一步的访问次数为 N 2 + 2 N d N^2 + 2Nd N2+2Nd,
  • 返回计算结果 O

总的注意力访问次数为 3 N d + 4 N 2 3Nd + 4N^2 3Nd+4N2, 5 次 HBM 读取,3 次 HBM 写入。

FlashAttention 的总访问次数为 N 2 d 2 M \frac{N^2 d^2}{M} MN2d2

2 Flash-attention 流程

flash attention 将 online-softmax 和矩阵分块结合起来计算 attention,将本来不能分块的 row 可以拆分成多个更细粒度的 Block。

算法思想:

  • 减少 HBM 的访问,将 QKV 切分为小块后放入 SRAM 中
  • tiling, recomputation
2.1 tiling(平铺): 分块计算

因为 Attention 计算中涉及 Softmax,所以不能简单的分块后直接计算。

softmax 计算方法:

s o f t m a x ( x j ) = e x j ∑ i = 1 k e x i softmax(x_j) = \frac{e^{x_j}}{\sum _{i=1} ^k e^{x_i}} softmax(xj)=i=1kexiexj

softmax 操作是 row-wise 的,即每行都算一次 softmax,所以需要用到平铺算法来分块计算 softmax。

注意力机制的计算过程是:

矩阵乘法 --> scale --> mask --> softmax --> dropout --> 矩阵乘法

矩阵乘法和逐点操作(scale,mask,dropout)的分块计算是容易实现的,难点在于 softmax 的分块计算。由于计算 softmax 的归一化因子(分母)时,需要获取到完整的输入数据,进行分块计算的难度比较大。

2.2 safe softmax

原始 softmax 数值不稳定。在实际硬件中,浮点数表示的范围是有限的,对于 float32 和 bfloat16 来说,当 x ≥ 89 x \geq 89 x89 时, e x e^x ex 就会变得很大甚至变成inf,发生数据上溢的问题。为了避免发生数值溢出的问题,保证数值稳定性,计算时通常会减去最大值,称为safe softmax

FlashAttention 采用 safe softmax:

m ( x ) = m a x ( [ x 1 , x 2 , ⋯   , x d ] ) m(x) = max([x_1, x_2, \cdots, x_d]) m(x)=max([x1,x2,,xd])

s o f t m a x ( x j ) = e x j − m ( x ) ∑ i = 1 k e x i − m ( x ) softmax(xj) = \frac{e^{x_j - m(x)}}{\sum _{i=1} ^k e^{x_i - m(x)}} softmax(xj)=i=1kexim(x)exjm(x)

这里通过引入 m ( x ) m(x) m(x),不仅能够一定程度上解决 softmax 的数值稳定性问题,还 使得在不同的 block 间汇总计算 softmax 成为了可能。

2.3 recomputation(重计算)

FlashAttention 算法的目标:在计算中减少显存占用,从 O ( N 2 ) O(N^2) O(N2) 大小降低到线性,把数据加载到 SRAM 中,提高 IO 速度。

传统 Attention 在计算中需要用到 Q,K,V 去计算 S,P 两个矩阵,FlashAttention 引入 softmax 中的统计量 ( m , l ) (m, l) (m,l) ,结合 output O 和在 SRAM 中的 Q,K,V 块进行计算。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值