【大模型训练】Flash Attention详解

前言

FlashAttention系列工作,是一种加速注意力计算方法,目前已经应用在:GPT-3、Falcon2(阿联酋大模型)、Llama2、Megatron-LM、GPT-4等流行LLM上。并且FlashAttention2已经集成到了pytorch2.0中,可以很便捷的调用。

1. FlashAttention动机Transformers are slow and memory-hungry on long sequences, since the time and memory complexity of self-attention are quadratic in sequence length.,可以看出由于Transformer中self-attention 的时间和内存复杂度是序列长度的二次方,所以序列过长时,算法速度会变慢,需要消耗很高的内存,导致低效的

2. FlashAttention主要贡献

  • FlashAttention利用底层硬件的内存层次知识,例如GPU的内存层次结构,来提高计算速度和减少内存访问开销
  • 核心原理是通过将输入分块,并在每个块上执行注意力操作,从而减少对高带宽内存(HBM)的读写操作
  • FlashAttention减少了内存读写量,从而实现了2-4倍的计算加速

预备知识

1. GPU 内存层次结构

GPU 内存层次结构 包含多种不同大小和速度的内存形式,内存容量越小,读写速度越快。以A100 GPU为例,主要有两种类型,如下图所示:

在这里插入图片描述

  • 高带宽内存 (HBM):也就是我们常说的GPU显存,A100具有 40-80GB HBM,带宽为 1.5-2.0TB/s
  • SRAM:位于GPU片上,每个 108 个流式多处理器(SM, Streaming Multiprocessor)都有 192KB 片上 SRAM,带宽估计约为 19TB/s
  • 二者的位置分布如下图所示,其中HBM在VRAM部门,而SRAM在GPU内部:

在这里插入图片描述

可以看到,片上 SRAM 比 HBM 快一个数量级,但内存容量小很多数量级。随着计算相对于内存速度变得更快,内存 (HBM) 访问越来越成为操作瓶颈。因此,利用快速 SRAM 变得更加重要。

2. GPU执行过程

GPU 有大量线程(threads )来执行操作(称为内核 Kernel)。每个Kernel将输入从 HBM 加载到寄存器和 SRAM,进行计算,然后将输出写入 HBM。

3. 性能特点

根据计算和内存访问的平衡,操作可以分为计算限制或内存限制。这通常通过算术强度来衡量,即内存访问的每个字节的算术运算数量。

  • 计算限制:操作所花费的时间由算术运算的数量决定,而访问 HBM 的时间要少得多。典型的例子是大内部维度的矩阵乘法,以及大量通道的卷积
  • 内存限制:操作所花费的时间由内存访问次数决定,而计算所花费的时间要少得多。示例包括大多数其他操作:逐元素(激活、dropout)和归约(求和、softmax、批量归一化、层归一化)

FlashAttention1

在这里插入图片描述

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
papercode

传统Attention计算方式

传统的Attention计算过程如下:

在这里插入图片描述

  • 首先 Q , K , V Q,K,V Q,K,V矩阵计算好,放在HBM中
  • 接着,为了计算QK注意力得分,将 Q , K Q,K Q,K从HBM中取出来写入SRAM,然后计算 S = Q K T S=QK^T S=QKT,再把 S S S从SRAM写入HBM
  • 然后,从HBM加载 S S S到SRAM,计算 P = S o f t m a x ( S ) P=\rm{Softmax}(\it{S}) P=Softmax(S),然后再把 P P P从SRAM写入HBM
  • 最后,从HBM加载 P , V P,V P,V到SRAM,计算最终的输出 O = P V O=PV O=PV,然后把 O O O写入HBM

可以看到,这个过程中存在多次HBM和SRAM之间的读写操作。同时由于Attention的计算方式,导致中间的临时变量 S , P S,P S,P的参数量和输入序列长度的平方成正比:

在这里插入图片描述

因此,在训练时,长序列输入在计算Attention时会产生更大参数量的临时变量,会占用更大显存空间,导致更多的访问消耗。也就是说,Attention操作主要是内存限制问题,通信时间是制约计算效率的主要因素。

FlashAttention1的基本原理

因此,FlashAttention主要的思想,就是减少通信时间,也就是减少IO操作,使得计算尽可能多的访问片上的SRAM,尽可能少的访问片外的HBM。

  • 通过分块计算,融合多个操作,减少中间结果缓存
  • 反向传播时,重新计算中间结果(类似于梯度检查点的原理)

在这里插入图片描述

除去Softmax操作的分块计算

在计算Attention主要有两个临时变量 S , P S,P S,P,FlashAttention的分块计算,使得不需要存储这两个临时变量,而是直接在SRAM计算得到部分最终结果 O O O,从而减少了内存访问开销。这里我们先忽略Softmax操作,因为在分块计算,他比较麻烦。

这里假设 Q , K , V Q,K,V Q,K,V矩阵的大小为 ( 4 , 3 ) (4, 3) (4,3),那么FlashAttention的分块计算过程如下,由于矩阵乘法的性质,每次分块计算得到的结果,都是最终结果矩阵中的一部分值:
在这里插入图片描述

Softmax分块计算

下面,我们来解决Softmax这个麻烦的操作,首先来回顾Softmax的计算公式:

softmax ⁡ ( { x 1 , … , x N } ) = { e x i ∑ j = 1 N e x j } i = 1 N \operatorname{softmax}\left(\left\{x_1, \ldots, x_N\right\}\right)=\left\{\frac{e^{x_i}}{\sum_{j=1}^N e^{x_j}}\right\}_{i=1}^N softmax({x1,,xN})={j=1Nexjexi}i=1N

但是,如果数据类型为FP16,那么最大可以表示为65536,因此当 x i x_i xi为12时, e 12 = 162754 e^{12}=162754 e12=162754,超过了FP16所能表示的最大值。因此,我们需要使用Safe_Softmax方法来避免这个问题:

m = max ⁡ ( x i ) softmax ⁡ ( { x 1 , … , x N } ) = { e x i / e m ∑ j = 1 N e x j / e m } i = 1 N = { e x i − m ∑ j = 1 N e x j − m } i = 1 N \begin{aligned} & m=\max \left(x_i\right) \\ & \operatorname{softmax}\left(\left\{x_1, \ldots, x_N\right\}\right)=\left\{\frac{e^{x_i} / e^m}{\sum_{j=1}^N e^{x_j} / e^m}\right\}_{i=1}^N=\left\{\frac{e^{x_i-m}}{\sum_{j=1}^N e^{x_j-m}}\right\}_{i=1}^N \end{aligned} m=max(xi)softmax({x1,,xN})={j=1Nexj/emexi/em}i=1N={j=1Nexjmexim}i=1N

也就是在计算Softmax之前,先对输入数据进行归一化处理。此时计算Softmax的流程为:

x = [ x 1 , … , x N ] m ( x ) : = max ⁡ ( x ) p ( x ) : = [ e x 1 − m ( x ) , … , e x N − m ( x ) ] l ( x ) : = ∑ i p ( x ) i softmax ⁡ ( x ) : = p ( x ) l ( x ) \begin{aligned} & x=\left[x_1, \ldots, x_N\right] \\ & m(x):=\max (x) \\ & p(x):=\left[e^{x_1-m(x)}, \ldots, e^{x_N-m(x)}\right] \\ & l(x):=\sum_i p(x)_i \\ & \operatorname{softmax}(x):=\frac{p(x)}{l(x)} \end{aligned} x=[x1,,xN]m(x):=max(x)p(x):=[ex1m(x),,exNm(x)]l(x):=ip(x)isoftmax(x):=l(x)p(x)

接下来看分块处理时,假设这里分为两块处理,我们首先需要在每个块内找到最大值(使用临时变量来保存),做归一化处理:

x = [ x 1 , … , x N , … x 2 N ] x 1 = [ x 1 , … , x N ] x 2 = [ x N + 1 , … x 2 N ] m ( x 1 ) p ( x 1 ) l ( x 1 ) m ( x 2 ) p ( x 2 ) l ( x 2 ) \begin{aligned} & x=\left[x_1, \ldots, x_N, \ldots x_{2 N}\right] \\ & x^1=\left[x_1, \ldots, x_N\right] \quad x^2=\left[x_{N+1}, \ldots x_{2 N}\right] \\ & m\left(x^1\right) \quad p\left(x^1\right) \quad l\left(x^1\right) \quad m\left(x^2\right) \quad p\left(x^2\right) \quad l\left(x^2\right) \end{aligned} x=[x1,,xN,x2N]x1=[x1,,xN]x2=[xN+1,x2N]m(x1)p(x1)l(x1)m(x2)p(x2)l(x2)

然后再计算数据的全局最大值,并且更新 p ( x ) , l ( x ) p(x),l(x) p(x),l(x),最后计算得到输入数据的Softmax值:

  • 由于全局最大值,一定是各个块内最大值中的一个
  • 因此在更新 p ( x ) , l ( x ) p(x),l(x) p(x),l(x),只需要乘以每个块最大值相对于全局最大值的差值的指数,就可以了
    m ( x ) : = max ⁡ ( m ( x 1 ) , m ( x 2 ) ) p ( x ) : = [ e m ( x 1 ) − m ( x ) p ( x 1 ) , e m ( x 2 ) − m ( x ) p ( x 2 ) ] l ( x ) : = e m ( x 1 ) − m ( x ) l ( x 1 ) + e m ( x 2 ) − m ( x ) l ( x 2 ) softmax ⁡ ( x ) : = p ( x ) l ( x ) \begin{aligned} &\begin{aligned} & m(x):=\max \left(m\left(x^1\right), m\left(x^2\right)\right) \\ & p(x):=\left[e^{m\left(x^1\right)-m(x)} p\left(x^1\right), e^{m\left(x^2\right)-m(x)} p\left(x^2\right)\right] \\ & l(x):=e^{m\left(x^1\right)-m(x)} l\left(x^1\right)+e^{m\left(x^2\right)-m(x)} l\left(x^2\right) \end{aligned}\\ &\operatorname{softmax}(x):=\frac{p(x)}{l(x)} \end{aligned} m(x):=max(m(x1),m(x2))p(x):=[em(x1)m(x)p(x1),em(x2)m(x)p(x2)]l(x):=em(x1)m(x)l(x1)+em(x2)m(x)l(x2)softmax(x):=l(x)p(x)

Attention分块计算

最后,我们来看一下FlashAttention的完整计算流程:
在这里插入图片描述

FlashAttention2

在这里插入图片描述
相比于FlashAttention1的改进:

  • 减少了非矩阵乘法计算,可以利用Tensor Core加速计算
  • 调整了内外训练方式,改为 Q 为外层循环,KV 为内层循环,进一步减少HBM读写,增加了并行度
  • 如果一个Block处于矩阵上三角部分(Mask机制),则不进行attention计算,进一步优化了计算效率

参考资料

  • [1] https://www.bilibili.com/video/BV1UT421k7rA/?share_source=copy_web&vd_source=79b1ab42a5b1cccc2807bc14de489fa7
  • [2] https://zhuanlan.zhihu.com/p/676655352
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

嗜睡的篠龙

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值