Flash Attention1-真正意义上的scale dot product attention的算子融合(从算法层面加速训练)

本篇文章和下篇文章精简记录一下flashattention1和flashattention2的主要思想和解决问题,首发于我的公众号“AI不止算法”,文章链接在此,本文先记录flash attention1.

引言/动机

FlashAttention是继standard Attention和Memory-efficient Attention(论文:《Self-attention Does Not Need O(N^2) Memory》后,当前主流的Attention优化方法。

在transformer或MHA的profiling中,一般最耗时的部分都在scaled dot product attention部分,它之所以耗时,不是因为计算量大,而是访存量大,访存量大的原因是QKT矩阵的shape为[batch_size, num_heads, seqlen, seqlen],读写内存量与seqlen成平方阶,这导致对HBM或者显存的memory traffic非常大;而且在反向传播中,为了计算QKT matmul输入和权重的梯度,还要再读[batch_size, num_heads, seqlen, seqlen]的QKT来做矩阵乘法,这就导致在训练的计算图中,我们要保存[batch_size, num_heads, seqlen, seqlen]大小的数据在HBM或者显存直到这几个算子(matmul、scale、mask、softmax)的反向传播计算完成。这是非常耗显存的。

图片

解决这种访存问题的惯用手法是算子融合,即把以上左图的这几个算子融合成一个算子。

挑战

在Flash Attention出来之前,已经有了很多了fusedattention算子,但是仔细看,可以发现这其实不是真正的融合算子,只是把matmul kernel、scale kernel、softmax kernel的接口在一个fusedattention算子里面按照计算顺序调了一下,这种手法最多减少了pytorch、TF等框架对算子的调度开销,其实不能真正解决对HBM或者显存的memory traffic。

融合MHA的挑战在于两点:

1.解决softmax,因为softmax是一个row-wise(以行为单位)的操作,必须要遍历softmax一行才能得到结果,由此,后面的matmul不得不等待这个过程,导致并行性降低

2.在寄存器和shared memory复用数据做计算,而不是去HBM或显存上去读数据来计算,然而寄存器数量和shared memory大小都有限,在1的情况下,显然无法将softmax的结果存到这两个存储单元里面供下一个matmul复用,下一个matmul不得不去HBM或显存上读数据

下图是GPU上HBM、shared memory的容量和带宽,和CPU上的memory层次结构也类似,越快的存储单元容量越小带宽越大。
在这里插入图片描述

Flash Attention

flash attention的出现,解决了上述两个挑战,通过分块计算softmax得到一个中间结果,并保存在shared memory,送到下一个matmul计算。

具体算法:

# 计算公式
S = Q * (K.transpose(-1, -2))
P = softmax(S) 
O = P * V

在这里插入图片描述

总的来说,关键在于,把QKVO四个矩阵分块计算,KV矩阵的列维度为外循环(L5),QO矩阵的行维度为内循环(L7),每次外循环我们将得到一行softmax的子块max(分子)和子块sum(分母)(L10),然后把这个中间结果与V矩阵相乘得到O的一行中间结果,而后随着外循环完成,O的一行中间结果不断的更新得到最终结果(L12)
在这里插入图片描述

这么说有点抽象,下面通过例子来说明

例子说明:附注释

# 计算公式
# S = Q * (K.transpose(-1, -2))
# P = softmax(S) 
# O = P * V
import numpy as np

N = 4
d = 2

S = np.random.random(size=(1, N))
V = np.random.random(size=(N, d))

def tiled_softmax_then_matmul(S, V):
  acc = np.zeros(shape=(1, d))
  pre_max = float("-inf")
  pre_sum = 0
  for i  in range(N): # 每个token,KV的列维度,为了简洁,这里把Q的行维度设为了1,因此没有了内循环
    s_i = S[:,i] # 每列S
    cur_max = max(pre_max, s_i) # 当前分块和之前分块一起的最大值
    pre_sum *= (np.exp(pre_max - cur_max)) # L10
    cur_sum = pre_sum + np.exp(s_i - cur_max) # 当前分块和之前分块一起的指数和
    score = np.exp(s_i - cur_max) / cur_sum # 当前分块的softmax结果
    scale = pre_sum / cur_sum # 因为上一个分块的结果是基于当时的softmax中间sum组成的分母(presum),现在这个分块又得到了新的中间sum(cursum),所以需要更新:对上一个分块的结果acc做一个scale,保证结果的正确性
    acc *= scale 
    acc += score * V[i,] # scale后的中间结果加上当前分块的P * V = O
    pre_max = cur_max 
    pre_sum = cur_sum 
  return acc

具体例子:
在这里插入图片描述
在这里插入图片描述

融合后

融合前,scale dot product attention由左边五个算子组成,融合后,变成了右边一个。
在这里插入图片描述

结论

Flash attention1的出现极大加速了transformer的训练速度,如下在这里插入图片描述

不过还是有很多可改进的地方,比如Q的循环可以放在外面,消除每次KV外循环都要去访问Q的开销;每次外循环都需要去用presum/cursum去rescale中间结果以得到最终结果,这些计算可以通过算法消除,使得尽可能减少CUDA core的计算,增加tensor core的计算,因为tensorcore的FLOPS比cuda core大得多,然而在多数真实情况下,tensor core和cuda core很难overlap,或者overlap的比率很低,所以减少cuda core的计算以提升性能。这些都会在下一篇文章介绍flashattention2介绍到。敬请期待~

欢迎关注我的公众号“AI不止算法”

  • 18
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值