FlashAttention知识点笔记

文章探讨了GPU内存层次结构在Transformer模型中的应用,特别是FlashAttention如何通过分块Softmax和SRAM计算来减少显存开销,以及如何在前向传播中避免存储注意力计算中间结果。文章还提到反向传播时如何通过重新计算策略获取梯度信息,以优化内存使用。
摘要由CSDN通过智能技术生成

参考笔记:https://readpaper.feishu.cn/docx/AC7JdtLrhoKpgxxSRM8cfUounsh

1. GPU的特性

GPU的内存层级图如上面所示,其中HBM是高带宽内存,也就是我们常说的显存。它通过将内存芯片直接堆叠在逻辑芯片上,提供了极高的带宽和更低的能耗,从而实现了高密度和高带宽的数据传输,SRAM和DRAM分别是静态随机存储器和动态随机存储器,其中GPU SRAM的速度最快但容量最少,FlashAttention在前向传播时使用分块softmax的方法实现整个注意力计算在GPU SRAM中运行而尽量减少对GPU HBM的访问,从而提高了计算速度,并且不存储注意力计算的中间结果S,P,从而减少了注意力机制的显存开销。   

2 传统注意力计算方法

传统注意力的缺点

Transformer 结构已成为自然语言处理和图像分类等应用中最常用的架构。尽管 Transformer 在规模上不断增大和加深,但处理更长上下文仍然是一个挑战,因为核心的自注意力模块在序列长度上具有二次方的时间和内存复杂度。这导致在处理长序列时速度变慢且内存需求巨大。因此,我们需要一些优化算法来提高注意力模块的计算速度和内存利用率

3.flashattention前向过程

3.1分块Softmax操作

在传统算法中,一种方式是将Mask和SoftMax部分融合,以减少访存次数。然而,FlashAttention则更加激进,它将从输入QKV到输出O的整个过程进行融合,以避免S,P矩阵的存储开销,实现端到端的延迟缩减。然而,由于输入的序列长度N通常很长,无法完全将完整的Q K V O 及中间计算结果存储在SRAM中。因此,需要依赖HBM进行访存操作,与原始计算延迟相比没有太大差异,甚至会变慢(没具体测)。下面展示普通softmax计算公式,与传统公式不同的是多了一步 -max(x) 的步骤,主要是为了数值稳定,防止溢出。分子分母同除e^max(x)其实结果等价,但计算时每一项值不会特别大

为了让计算过程的结果完全在SRAM中,摆脱对HBM的依赖,可以采用分片操作,前面直接一整行(如xi)进行softmax,而这里由于SRAM无法直接存储一整行的xi,所以可以把xi分成两个x(1)x(2)两个block,然后分别求出他们对应的f(x)

 每次进行部分计算,确保这些计算结果能在SRAM内进行交互,待得到对应的结果后再进行输出。当我们将输入进行分片后,无法对完整的行数据执行Softmax操作。这是因为Softmax函数在计算时需要考虑整个行的数据。然而,我们可以通过如下所示方法来获得与完整行Softmax相同的结果,而无需使用近似操作

3.2 前向代码实现

@triton.jit
def _fwd_kernel(
    #输入QKV矩阵,softmax使用的dk(scale)
    Q, K, V, sm_scale,
    #logsum   max值
    L, M,
    #最后输出结果
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
    off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
    off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
    # Initialize pointers to Q, K, V
    q_ptrs = Q + off_q
    k_ptrs = K + off_k
    v_ptrs = V + off_v
    # initialize pointer to m and l
    m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # 直接把q 给load进SRAM
    q = tl.load(q_ptrs)
    # loop over k, v and update accumulator
    for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
        # -- compute qk ----
        k = tl.load(k_ptrs)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale
        qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        # compute new m
        m_curr = tl.maximum(tl.max(qk, 1), m_prev)
        # correct old l
        l_prev *= tl.exp(m_prev - m_curr)
        # attention weights
        p = tl.exp(qk - m_curr[:, None])
        l_curr = tl.sum(p, 1) + l_prev
        # rescale operands of matmuls
        l_rcp = 1. / l_curr
        p *= l_rcp[:, None]
        acc *= (l_prev * l_rcp)[:, None]
        # update acc
        p = p.to(Q.dtype.element_ty)
        v = tl.load(v_ptrs)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_prev = l_curr
        m_prev = m_curr
        # update pointers
        k_ptrs += BLOCK_N * stride_kn
        v_ptrs += BLOCK_N * stride_vk
    # rematerialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    l_ptrs = L + off_hz * N_CTX + offs_m
    m_ptrs = M + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, l_prev)
    tl.store(m_ptrs, m_prev)
    # initialize pointers to output
    offs_n = tl.arange(0, BLOCK_DMODEL)
    off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc)

同时,在前向的过程中,我们采用分块计算的方式,避免了中间计算过程S, P矩阵的存储开销,整体的运算都在SRAM内进行,降低了HBM访问次数,大大提升了计算的速度,减少了对存储的消耗

4. flashAttention的反向传播过程

前向的时候我们为了减少hbm的访问次数,降低内存消耗,我们其实没有对这个sp矩阵进行存储,但我们在反向传播计算梯度的时候,确实是需要sp矩阵的信息,我们这里则采用重新计算的方式来计算对应的梯度,在上面前向计算的时候我们不会存储sp矩阵,但我们会存储对应 的指数之和L来进行梯度的计算

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值