flash attention 参数(笔记)

目录

一、flash attention官方

  1.1、flash attention安装

二、flash attention 常见函数

  2.1、flash_attn_varlen_qkvpacked_func

  2.2、flash_attn_varlen_kvpacked_func

  2.3、flash_attn_varlen_func

  ​​​​​​​2.4、flash_attn_with_kvcache

  2.5、flash_attn_func


一、flash attention官方

版本: flash-attn  2.5.7

flash-attention/flash_attn/flash_attn_interface.py at main · Dao-AILab/flash-attention · GitHubFast and memory-efficient exact attention. Contribute to Dao-AILab/flash-attention development by creating an account on GitHub.icon-default.png?t=N7T8https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py

  1.1、flash attention安装(下文1.2节)

Trl SFT: llama2-7b-hf使用QLora 4bit量化后ds zero3加上flash atten v2单机多卡训练(笔记)_unsloth 多卡-CSDN博客文章浏览阅读812次,点赞18次,收藏25次。第三 参考官方命令: https://github.com/Dao-AILab/flash-attention。第一 确保 linux "外界"的 cuda版本 与 conda 虚拟环境中cuda版本一致。第二 安装好 c++ g++ ninja。_unsloth 多卡https://blog.csdn.net/qq_16555103/article/details/137677561

二、flash attention 常见函数

  2.1、flash_attn_varlen_qkvpacked_func

输入前、输出后需要使用unpad、pad

该函数 flash_attn_varlen_qkvpacked_func 用于计算可变长序列的注意力输出,其中 query、key 和 value 已被打包成一个张量。

该函数的主要作用是:

  1. 高效计算注意力输出:通过将 query、key 和 value 打包成一个张量作为输入,避免了显式连接 Q、K、V 的梯度,从而提高了计算效率。

  2. 支持变长序列:函数通过 cu_seqlens 参数接收每个序列的累积长度,可以有效处理变长序列的情况。

  3. 支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。

  4. 提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。

  5. 返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。

该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。

该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式和变长序列输入。

def flash_attn_varlen_qkvpacked_func(
    qkv,  # query、key 和 value 张量,形状为 (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起。例如,如果批次大小为 2,序列长度分别为 3 和 5,头数为 4,head 维度为 64,则 qkv 的形状为 (8, 3, 4, 64)
    cu_seqlens,  # 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,用于索引 qkv。例如,如果批次大小为 2,序列长度分别为 3 和 5,则 cu_seqlens 为 [0, 3, 8]
    max_seqlen,  # 批次中序列的最大长度,整数值。例如,如果批次中的序列长度分别为 3 和 5,则 max_seqlen 为 5
    dropout_p=0.0,  # dropout 概率,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。通常在训练时使用一个小于 1 的值,如 0.1,而在评估时设置为 0.0
    softmax_scale=None,  # 在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)。如果 headdim 为 64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126
    causal=False,  # 是否应用因果注意力掩码(用于自回归建模)。如果设置为 True,则查询只能关注之前的输出,无法关注未来的输出,这在语言模型等自回归任务中很有用
    window_size=(-1, -1),  # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 (2, 2),则查询位置 i 只能关注位置 [i-2, i+2] 范围内的键
    alibi_slopes=None,  # 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads 为 4,则 alibi_slopes 可以是形状为 (4,) 的张量
    deterministic=False,  # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。通常在评估时不需要使用确定性实现,因为评估时不需要计算梯度
    return_attn_probs=False,  # 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
    """
    dropout_p 在评估(evaluation)时应该设置为 0.0,以保留所有神经元的输出,防止信息损失。

    如果 query、key 和 value 已经打包成一个张量,调用这个函数会比调用 flash_attn_varlen_func 更快,因为反向传播避免了显式拼接 query、key 和 value 的梯度,从而减少了内存复制和计算量。

    对于多查询注意力(MQA)和分组查询注意力(GQA),请参见 flash_attn_varlen_kvpacked_func 和 flash_attn_varlen_func。

    如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 处的查询只会注意到位置在 [i - window_size[0], i + window_size[1]] 范围内(包括边界)的键。

    参数说明:
    qkv: (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起
    cu_seqlens: (batch_size + 1,),数据类型为 torch.int32,表示批次中每个序列的累积长度,用于索引 qkv
    max_seqlen: 整数,批次中序列的最大长度
    dropout_p: 浮点数,dropout 概率
    softmax_scale: 浮点数,在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)
    causal: 布尔值,是否应用因果注意力掩码(用于自回归建模)
    window_size: (left, right),整数元组,如果不是 (-1, -1),则实现滑动窗口局部注意力
    alibi_slopes: (nheads,) 或 (batch_size, nheads),fp32 张量,一种用于注意力分数偏置的方法
    deterministic: 布尔值,是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的
    return_attn_probs: 布尔值,是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确

    返回值:
    out: (total, nheads, headdim),注意力层的输出
    softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen),每行的 QK^T * scaling 的 logsumexp(即 softmax 归一化因子的对数)
    S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),softmax 的输出(可能有不同的缩放),它也编码了 dropout 模式(负值表示该位置被 dropout,非负值表示被保留)
    """

    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,  # query、key 和 value 张量
        cu_seqlens,  # 序列的累积长度
        max_seqlen,  # 批次中序列的最大长度    
        dropout_p,  # dropout 概率
        softmax_scale,  # softmax 缩放因子
        causal,  # 是否应用因果注意力掩码
        window_size,  # 滑动窗口大小,用于实现局部注意力
        alibi_slopes,  # 注意力分数偏置方法
        deterministic,  # 是否使用确定性反向传播实现
        return_attn_probs,  # 是否返回注意力概率(仅用于测试)
    )
  2.2、flash_attn_varlen_kvpacked_func

输入前、输出后需要使用unpad、pad

该函数 flash_attn_varlen_kvpacked_func 用于计算变长序列的注意力输出,其中 key 和 value 已被打包成一个张量。它是 Flash Attention 库中的一个函数,旨在高效地计算注意力输出,同时支持多查询注意力(MQA)和分组查询注意力(GQA)。

该函数的主要作用是:

  1. 高效计算注意力输出:通过将 key 和 value 打包成一个张量作为输入,避免了显式连接 K、V 的梯度,从而提高了计算效率。

  2. 支持变长序列:函数通过 cu_seqlens_q 和 cu_seqlens_k 参数接收每个序列的累积长度,可以有效处理变长序列的情况。

  3. 支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。

  4. 支持多查询注意力(MQA)和分组查询注意力(GQA):通过将 key 和 value 的注意力头数设置为少于 query 的注意力头数,可以实现 MQA 和 GQA。例如,如果 query 有 6 个注意力头,key 和 value 有 2 个注意力头,那么 query 的头 0、1、2 将关注 key 和 value 的头 0,query 的头 3、4、5 将关注 key 和 value 的头 1。

  5. 提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。

  6. 返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。

该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。

该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式、变长序列输入和 MQA/GQA 等特性,在自然语言处理等领域具有广泛的应用。

def flash_attn_varlen_kvpacked_func(
    q,  # query 张量,形状为 (total_q, nheads, headdim),例如 (1024, 16, 64),其中 total_q=1024 是批次中 query token 的总数,nheads=16 是注意力头数,headdim=64 是每个注意力头的维度
    kv,  # key-value 张量,形状为 (total_k, 2, nheads_k, headdim),例如 (2048, 2, 8, 64),其中 total_k=2048 是批次中 key token 的总数,2 表示 key 和 value 被打包在一起,nheads_k=8 是 key 和 value 的注意力头数,headdim=64 是每个注意力头的维度。注意,nheads_k 可以小于 nheads,这支持了多查询注意力(MQA)和分组查询注意力(GQA)的用法
    cu_seqlens_q,  # query 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,例如 [0, 10, 20, 32, 42],表示批次中有 4 个序列,第一个序列长度为 10,第二个序列长度为 10,第三个序列长度为 12,第四个序列长度为 10。这些累积长度用于索引 q 张量
    cu_seqlens_k,  # key 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,例如 [0, 15, 30, 45, 60],用于索引 kv 张量
    max_seqlen_q,  # 批次中 query 序列的最大长度,例如 12,表示批次中最长的 query 序列长度为 12
    max_seqlen_k,  # 批次中 key 序列的最大长度,例如 15,表示批次中最长的 key 序列长度为 15
    dropout_p=0.0,  # dropout 概率,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。在训练(training)时,可以设置一个小于 1 的值,例如 0.1,表示将有 10% 的神经元被随机丢弃
    softmax_scale=None,  # 在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim),例如如果 headdim=64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126,这是一种常见的缩放方式,可以使注意力分数的方差保持在合理范围内,防止出现较大的梯度
    causal=False,  # 是否应用因果注意力掩码(用于自回归建模),如果设置为 True,则查询只能关注之前的输出,无法关注未来的输出,这在诸如语言模型等自回归任务中非常有用
    window_size=(-1, -1),  # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 i, 关注窗口为 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的键
    alibi_slopes=None,  # 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads=16,则 alibi_slopes 可以是形状为 (16,) 的张量,每个元素对应一个注意力头的偏置斜率。如果设置了该参数,则会对查询 i 和键 j 之间的注意力分数加上一个偏置 (-alibi_slope * |i + seqlen_k - seqlen_q - j|),这种偏置可以鼓励模型关注更靠近查询的键,或者更远离查询的键,具体取决于 alibi_slopes 的值
    deterministic=False,  # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。在评估(evaluation)时,通常不需要使用确定性实现,因为评估时不需要计算梯度
    return_attn_probs=False,  # 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
    """
    dropout_p 在评估(evaluation)时应该设置为 0.0,以保留所有神经元的输出,防止信息损失。

    如果 K 和 V 已经打包成一个张量,调用这个函数会比调用 flash_attn_func 更快,因为反向传播避免了显式拼接 K 和 V 的梯度,从而减少了内存复制和计算量。

    支持多查询注意力(MQA)和分组查询注意力(GQA),只需将 KV 的头数设置为少于 Q 的头数即可。注意,Q 的头数必须能被 KV 的头数整除。
    例如,如果 Q 有 6 个头,而 K 和 V 有 2 个头,那么 Q 的头 0、1、2 将注意到 K 和 V 的头 0,而 Q 的头 3、4、5 将注意到 K 和 V 的头 1。
    这种机制使得模型可以在不同的头组之间共享计算资源,从而提高计算效率。

    如果 causal=True,则因果掩码对齐到注意力矩阵的右下角。这种掩码模式确保了在自回归(auto-regressive)建模中,查询只能关注之前的输出,无法关注未来的输出。
    例如,如果 seqlen_q = 2 且 seqlen_k = 5,则因果掩码(1 = 保留,0 = 掩码)为:
        1 1 1 1 0
        1 1 1 1 1
    如果 seqlen_q = 5 且 seqlen_k = 2,则因果掩码为:
        0 0
        0 0
        0 0 
        1 0
        1 1
    如果掩码的一行全为零,则该查询的输出也将为零,因为它无法关注任何有效的键。

    如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 处的查询只会注意到位置在
    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的键。
    例如,如果 seqlen_q = 3, seqlen_k = 5, window_size = (1, 2),则查询位置 0 只能关注键位置 [0, 3],查询位置 1 只能关注键位置 [1, 4],查询位置 2 只能关注键位置 [2, 4]。
    这种局部注意力机制可以显著提高计算效率,尤其是在处理长序列时,因为它减少了需要计算的注意力分数的数量。但同时也会牺牲一些表达能力,因为查询无法关注整个序列。
    """

    # 返回值:
    # out: (total, nheads, headdim),注意力层的输出,例如 (1024, 16, 64)
    # softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen),每行的 QK^T * scaling 的 logsumexp(即 softmax 归一化因子的对数),例如 (4, 16, 12)
    # S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),softmax 的输出(可能有不同的缩放),它也编码了 dropout 模式(负值表示该位置被 dropout,非负值表示被保留),例如 (4, 16, 12, 15)

    return FlashAttnVarlenKVPackedFunc.apply(
        q,  # query 张量
        kv,  # 打包的 key-value 张量
        cu_seqlens_q,  # query 序列的累积长度
        cu_seqlens_k,  # key 序列的累积长度 
        max_seqlen_q,  # 批次中 query 序列的最大长度
        max_seqlen_k,  # 批次中 key 序列的最大长度
        dropout_p,  # dropout 概率
        softmax_scale,  # softmax 缩放因子
        causal,  # 是否应用因果注意力掩码
        window_size,  # 滑动窗口大小,用于实现局部注意力
        alibi_slopes,  # 注意力分数偏置方法
        deterministic,  # 是否使用确定性反向传播实现
        return_attn_probs,  # 是否返回注意力概率(仅用于测试)
    )

"""
  一个具体的 GQA 示例:
    假设我们有一个批次,其中包含 4 个序列,每个序列的长度分别为 10、10、12 和 10。
    我们希望使用 GQA,其中 query 有 6 个头(nheads=6),而 key 和 value 有 2 个头(nheads_k=2)。
    每个注意力头的维度为 64(headdim=64)。

    在这种情况下,各个参数的值如下:

    q: (total_q, nheads, headdim) = (42, 6, 64)
       total_q = 10 + 10 + 12 + 10 = 42,是批次中所有序列的 token 总数
       nheads = 6,表示 query 有 6 个注意力头
       headdim = 64,表示每个注意力头的维度为 64

    kv: (total_k, 2, nheads_k, headdim) = (42, 2, 2, 64) 
        total_k = 10 + 10 + 12 + 10 = 42,与 total_q 相同,因为每个序列的 query 长度和 key 长度是相同的
        2 表示 key 和 value 被打包在一起
        nheads_k = 2,表示 key 和 value 有 2 个注意力头
        headdim = 64,表示每个注意力头的维度为 64

    cu_seqlens_q: [0, 10, 20, 32, 42],表示批次中每个序列的累积长度,用于索引 q
                  第一个序列的长度为 10,第二个序列的长度为 10,第三个序列的长度为 12,第四个序列的长度为 10

    cu_seqlens_k: [0, 10, 20, 32, 42],与 cu_seqlens_q 相同,因为每个序列的 query 长度和 key 长度是相同的

    max_seqlen_q: 12,批次中最长的 query 序列长度
    max_seqlen_k: 12,与 max_seqlen_q 相同,因为每个序列的 query 长度和 key 长度是相同的

    在这个 GQA 设置下,query 的 6 个头将被分成 2 组,每组 3 个头,分别关注 key 和 value 的 2 个头:
    - query 的头 0 1 2 将关注 key 和 value 的头 0
    - query 的头 3 4 5 将关注 key 和 value 的头 1
 

    这种分组机制允许不同的 query 头关注不同的 key 和 value 子空间,提高了计算效率和表达能力。
    同时,通过减少 key 和 value 的头数,可以显著降低计算和存储开销。
"""
  2.3、flash_attn_varlen_func

输入前、输出后需要使用unpad、pad

该函数 flash_attn_varlen_func 用于计算变长序列的注意力输出,其中 query、key 和 value 是分开的张量。它是 Flash Attention 库中的一个函数,旨在高效地计算注意力输出,同时支持多查询注意力(MQA)和分组查询注意力(GQA)。

该函数的主要作用是:

  1. 高效计算注意力输出:通过分开输入 query、key 和 value 张量,可以有效利用计算资源进行注意力计算。

  2. 支持变长序列:函数通过 cu_seqlens_q 和 cu_seqlens_k 参数接收每个序列的累积长度,可以有效处理变长序列的情况。

  3. 支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。

  4. 支持多查询注意力(MQA)和分组查询注意力(GQA):通过将 key 和 value 的注意力头数设置为少于 query 的注意力头数,可以实现 MQA 和 GQA。例如,如果 query 有 6 个注意力头,key 和 value 有 2 个注意力头,那么 query 的头 0、1、2 将关注 key 和 value 的头 0,query 的头 3、4、5 将关注 key 和 value 的头 1。

  5. 支持分块稀疏注意力:可以通过提供 block_table 参数来启用分块稀疏注意力,进一步提高计算效率。

  6. 提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。

  7. 返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。

该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。

该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式、变长序列输入、MQA/GQA 和分块稀疏注意力等特性,在自然语言处理等领域具有广泛的应用。

def flash_attn_varlen_func(
    q, # 输入的 query 张量,形状为 (total_q, nheads, headdim),其中 total_q 是批量中所有查询token的总数,nheads 是注意力头数,headdim 是每个注意力头的维度
    k, # 输入的 key 张量,形状为 (total_k, nheads_k, headdim),其中 total_k 是批量中所有 key token的总数,nheads_k 是 key 的注意力头数,headdim 是每个注意力头的维度
    v, # 输入的 value 张量,形状为 (total_k, nheads_k, headdim),其中 total_k 是批量中所有 key token的总数,nheads_k 是 key 的注意力头数,headdim 是每个注意力头的维度
    cu_seqlens_q, # 批量中每个查询序列的累积长度,形状为 (batch_size + 1,),数据类型为 torch.int32,用于从 q 中索引相应的位置
    cu_seqlens_k, # 批量中每个 key 序列的累积长度,形状为 (batch_size + 1,),数据类型为 torch.int32,用于从 k 和 v 中索引相应的位置
    max_seqlen_q, # 批量中最大查询序列长度
    max_seqlen_k, # 批量中最大 key 序列长度
    dropout_p=0.0, # Dropout 率,在评估(evaluation)时应设置为 0.0
    softmax_scale=None, # softmax 缩放因子,默认为 1 / sqrt(headdim)
    causal=False, # 是否应用因果注意力掩码,用于自回归(auto-regressive)建模
    window_size=(-1, -1), # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限制上下文窗口
    alibi_slopes=None, # 用于添加注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32
    deterministic=False, # 是否使用确定性反向传播实现,比非确定性实现稍慢但使用更多内存,前向传播始终是确定性的
    return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能不具有正确缩放
    block_table=None # 可选的块表,用于分块稀疏注意力
):
    """
    解决问题: 计算变长序列的注意力输出,其中 query、key 和 value 是分开的张量。支持多查询注意力(MQA)和分组查询注意力(GQA)。

    注意事项:
    - dropout_p 应在评估时设置为 0.0。
    - 支持多查询注意力(MQA)和分组查询注意力(GQA),通过将 K、V 的注意力头数设置为少于 Q 的注意力头数来实现。Q 的注意力头数必须能被 K、V 的注意力头数整除。
      例如,如果 Q 有 6 个注意力头,K、V 有 2 个注意力头,那么 Q 的头 0、1、2 将关注 K、V 的头 0,Q 的头 3、4、5 将关注 K、V 的头 1。
    - 如果 causal=True,因果掩码将与注意力矩阵的右下角对齐。
      例如,如果 seqlen_q = 2 且 seqlen_k = 5,因果掩码(1 = 保留,0 = 掩码)为:
        1 1 1 1 0
        1 1 1 1 1
      如果 seqlen_q = 5 且 seqlen_k = 2,因果掩码为:
        0 0
        0 0
        0 0
        1 0
        1 1
      如果掩码的一行全为零,输出也将为零。
    - 如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 的查询将只关注位于 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内的 key。
    - 可以通过提供 block_table 参数来启用分块稀疏注意力。

    返回值:
    - out: 注意力层的输出张量,形状为 (total, nheads, headdim),其中 total = total_q。
    - softmax_lse [可选,如果 return_attn_probs=True]: 每行的 QK^T * scaling 的 logsumexp 值,形状为 (batch_size, nheads, seqlen),即 softmax 归一化因子的对数。
    - S_dmask [可选,如果 return_attn_probs=True]: softmax 的输出,可能具有不同的缩放,形状为 (batch_size, nheads, seqlen, seqlen),还编码了 Dropout 模式(负值表示该位置被丢弃,非负值表示该位置被保留)。
    """
    return FlashAttnVarlenFunc.apply(
        q, # 输入的 query 张量
        k, # 输入的 key 张量
        v, # 输入的 value 张量
        cu_seqlens_q, # 批量中每个查询序列的累积长度
        cu_seqlens_k, # 批量中每个 key 序列的累积长度
        max_seqlen_q, # 批量中最大查询序列长度
        max_seqlen_k, # 批量中最大 key 序列长度
        dropout_p, # Dropout 率
        softmax_scale, # softmax 缩放因子
        causal, # 是否应用因果注意力掩码
        window_size, # 用于实现滑动窗口局部注意力
        alibi_slopes, # 用于添加注意力分数偏置
        deterministic, # 是否使用确定性反向传播实现
        return_attn_probs, # 是否返回注意力概率
        block_table, # 可选的块表,用于分块稀疏注意力
    )
  2.4、flash_attn_with_kvcache

该函数 flash_attn_with_kvcache 用于在推理(inference)过程中计算注意力层的输出,同时支持使用 key 和 value 缓存,以及旋转位置嵌入等技术。它是一个高效的注意力计算函数,可以在推理时加速序列生成任务。

函数的主要特点和技术如下:

  1. 支持更新 key 和 value 缓存:如果提供了新的 k 和 v 张量,函数会将它们的值原地更新到 k_cache 和 v_cache 中。这对于增量解码非常有用,可以一次性完成缓存更新和注意力计算。

  2. 旋转位置嵌入 (Rotary Position Embedding):如果提供了 rotary_cos 和 rotary_sin,函数会对 qk 应用旋转位置嵌入。旋转位置嵌入是一种编码序列位置信息的方法,可以提高注意力模型在长序列任务中的性能。

  3. 因果注意力掩码 (Causal Attention Mask):如果设置 causal=True,函数会应用因果注意力掩码,确保模型只关注当前位置之前的输出,实现自回归(auto-regressive)特性。

  4. 滑动窗口局部注意力 (Sliding Window Local Attention):如果设置 window_size != (-1, -1),函数会实现滑动窗口局部注意力,对于每个查询,只关注一定窗口范围内的 key。这可以减少计算开销,适用于一些特定任务。

  5. 多查询和分组查询注意力 (MQA/GQA):函数支持将 q 的头数设置为 kv 头数的整数倍,实现多查询和分组查询注意力。这种技术可以提高计算效率。

  6. 分块 key/value 缓存:如果提供了 block_table,函数会将 k_cache 和 v_cache 视为分页缓存,支持高效的缓存管理。

  7. 注意力分数偏置 (Alibi Slopes):如果提供了 alibi_slopes,函数会为每个查询-key 对的注意力分数加上一个与位置相关的偏置项。这是一种正则化技术,可以改善注意力模型的性能。

  8. CUDA 内核加速:函数的核心计算由 CUDA 内核 flash_attn_cuda.fwd_kvcache 完成,提供高性能的并行计算能力。

总的来说,flash_attn_with_kvcache 函数集成了多种先进的注意力计算技术,可以高效地

def flash_attn_with_kvcache(
    q, # 查询张量,形状为 (batch_size, seqlen, nheads, headdim)
    k_cache, # key 缓存张量,形状为 (batch_size_cache, seqlen_cache, nheads_k, headdim) 或 (num_blocks, page_block_size, nheads_k, headdim)
    v_cache, # value 缓存张量,形状与 k_cache 相同
    k=None, # 可选的新 key 张量,形状为 (batch_size, seqlen_new, nheads_k, headdim)
    v=None, # 可选的新 value 张量,形状与 k 相同
    rotary_cos=None, # 可选的旋转位置嵌入余弦值,形状为 (seqlen_ro, rotary_dim / 2)
    rotary_sin=None, # 可选的旋转位置嵌入正弦值,形状与 rotary_cos 相同
    cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, # 缓存序列长度,可以是整数或张量
    cache_batch_idx: Optional[torch.Tensor] = None, # 缓存批次索引张量,形状为 (batch_size,)
    block_table: Optional[torch.Tensor] = None, # 可选的块表张量,形状为 (batch_size, max_num_blocks_per_seq)
    softmax_scale=None, # softmax 缩放系数,默认为 1 / sqrt(headdim)
    causal=False, # 是否进行因果注意力掩码
    window_size=(-1, -1), # 滑动窗口大小,(-1, -1)表示无限上下文窗口
    rotary_interleaved=True, # 是否交错旋转位置嵌入
    alibi_slopes=None, # 可选的注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads)
    num_splits=0, # 将 key/value 沿序列维度分割的数量,0 表示自动确定
):
    """
    该函数用于在推理(inference)过程中计算注意力层的输出,同时支持使用 key 和 value 缓存,以及旋转位置嵌入等技术。

    如果 k 和 v 不为 None,则会将它们的新值原地更新到 k_cache 和 v_cache 中。这对于增量解码很有用:
    你可以传入上一步的缓存 key/value,并使用当前步的新 key/value 进行更新,然后使用更新后的缓存进行注意力计算,所有操作都在一个内核中完成。

    如果你传入了 k/v,你必须确保缓存足够大,可以容纳新的值。例如,KV 缓存可以预先分配最大序列长度,并使用 cache_seqlens 跟踪每个序列在批次中的当前长度。

    如果传入了 rotary_cos 和 rotary_sin,则会应用旋转位置嵌入。key @k 将在索引 cache_seqlens、cache_seqlens + 1 等处被 rotary_cos 和 rotary_sin 旋转。
    如果是因果注意力或局部注意力(即 window_size != (-1, -1)),则查询 @q 将在索引 cache_seqlens、cache_seqlens + 1 等处被 rotary_cos 和 rotary_sin 旋转。
    如果既不是因果注意力也不是局部注意力,则查询 @q 将只在索引 cache_seqlens 处被 rotary_cos 和 rotary_sin 旋转(即我们认为 @q 中的所有标记都位于 cache_seqlens 位置)。

    该函数支持多查询和分组查询注意力(MQA/GQA),方法是将 KV 的头数量设置为少于 Q 的头数量。注意 Q 中的头数量必须能被 KV 中的头数量整除。
    例如,如果 Q 有 6 个头,而 K、V 有 2 个头,那么 Q 的第 0、1、2 个头将关注 K、V 的第 0 个头,而 Q 的第 3、4、5 个头将关注 K、V 的第 1 个头。

    如果 causal=True,则因果掩码对齐到注意力矩阵的右下角。例如,如果 seqlen_q = 2 且 seqlen_k = 5,则因果掩码(1 = 保留,0 = 掩码)为:
        1 1 1 1 0
        1 1 1 1 1
    如果 seqlen_q = 5 且 seqlen_k = 2,则因果掩码为:
        0 0
        0 0
        0 0
        1 0
        1 1
    如果掩码行全为 0,则输出也将为 0。

    如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 的查询将只关注位置在 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的 key。

    注意:不支持反向传播。

    参数:
        q: (batch_size, seqlen, nheads, headdim)
        k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) 如果没有 block_table,
            或 (num_blocks, page_block_size, nheads_k, headdim) 如果有 block_table (即分页 KV 缓存)
            page_block_size 必须是 256 的倍数。
        v_cache: 形状与 k_cache 相同
        k [可选]: (batch_size, seqlen_new, nheads_k, headdim)。如果不为 None,我们将它与 k_cache 拼接,从 cache_seqlens 指定的索引开始。
        v [可选]: (batch_size, seqlen_new, nheads_k, headdim)。与 k 类似。
        rotary_cos [可选]: (seqlen_ro, rotary_dim / 2)。如果不为 None,我们将对 k 和 q 应用旋转位置嵌入。只在 k 和 v 被传入时适用。rotary_dim 必须能被 16 整除。
        rotary_sin [可选]: (seqlen_ro, rotary_dim / 2)。与 rotary_cos 类似。
        cache_seqlens: int 或 (batch_size,), 数据类型 torch.int32。KV 缓存的序列长度。
        block_table [可选]: (batch_size, max_num_blocks_per_seq), 数据类型 torch.int32。
        cache_batch_idx: (batch_size,), 数据类型 torch.int32。用于索引 KV 缓存的索引。如果为 None,我们假设批次索引为 [0, 1, 2, ..., batch_size - 1]。如果索引不是唯一的,并且提供了 k 和 v,那么缓存中更新的值可能来自任何重复的索引。
        softmax_scale: float。QK^T 在应用 softmax 之前的缩放系数。默认为 1 / sqrt(headdim)。
        causal: bool。是否应用因果注意力掩码(例如用于自回归建模)。
        window_size: (left, right)。如果不是 (-1, -1),则实现滑动窗口局部注意力。
        rotary_interleaved: bool。只在传入 rotary_cos 和 rotary_sin 时适用。如果为 True,旋转位置嵌入将组合维度 0 & 1、2 & 3 等。如果为 False,旋转位置嵌入将组合维度 0 & rotary_dim / 2、1 & rotary_dim / 2 + 1(即 GPT-NeoX 风格)。
        alibi_slopes: (nheads,) 或 (batch_size, nheads), fp32。将 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 的偏置加到查询 i 和 key j 的注意力分数上。
        num_splits: int。如果大于 1,则将 key/value 沿序列维度分割成这么多块。如果 num_splits == 1,我们不分割 key/value。如果 num_splits == 0,我们使用启发式方法自动确定分割数量。除非你知道你在做什么,否则不要更改这个参数。

    返回:
        out: (batch_size, seqlen, nheads, headdim)。
    """
    # 确保 k_cache 和 v_cache 的最后一维是连续的
    assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
    assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
    
    # 如果 q、k 或 v 的最后一维不是连续的,则对它们进行连续化
    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    
    # 如果没有指定 softmax 缩放系数,则使用默认值 1 / sqrt(headdim)
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    
    # 如果 cache_seqlens 是整数,则将其转换为张量
    if cache_seqlens is not None and isinstance(cache_seqlens, int):
        cache_seqlens = torch.full(
            (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device
        )
        cache_seqlens = maybe_contiguous(cache_seqlens)
    
    # 确保 cache_batch_idx 和 block_table 是连续的
    cache_batch_idx = maybe_contiguous(cache_batch_idx)
    block_table = maybe_contiguous(block_table)
    
    # 调用 CUDA 内核计算注意力输出
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
        q,
        k_cache,
        v_cache,
        k,
        v,
        cache_seqlens,
        rotary_cos,
        rotary_sin,
        cache_batch_idx,
        block_table,
        alibi_slopes,
        None,
        softmax_scale,
        causal,
        window_size[0],
        window_size[1],
        rotary_interleaved,
        num_splits,
    )
    
    return out
  2.5、flash_attn_func

该函数 flash_attn_func 用于计算注意力层的输出,是一个高效的注意力计算函数,支持多种先进的注意力技术。

函数的主要特点和技术如下:

  1. 多查询和分组查询注意力 (MQA/GQA):函数支持将 q 的头数设置为 kv 头数的整数倍,实现多查询和分组查询注意力。这种技术可以提高计算效率。

  2. 因果注意力掩码 (Causal Attention Mask):如果设置 causal=True,函数会应用因果注意力掩码,确保模型只关注当前位置之前的输出,实现自回归(auto-regressive)特性。

  3. 滑动窗口局部注意力 (Sliding Window Local Attention):如果设置 window_size != (-1, -1),函数会实现滑动窗口局部注意力,对于每个查询,只关注一定窗口范围内的 key。这可以减少计算开销,适用于一些特定任务。

  4. 注意力分数偏置 (Alibi Slopes):如果提供了 alibi_slopes,函数会为每个查询-key 对的注意力分数加上一个与位置相关的偏置项。这是一种正则化技术,可以改善注意力模型的性能。

  5. 确定性反向传播 (Deterministic Backward Pass):函数支持使用确定性反向传播实现,虽然稍慢但使用更多内存,正向传播始终是确定性的。

  6. 返回注意力概率:如果设置 return_attn_probs=True,函数会返回注意力概率,但这只用于测试,返回的概率可能由于缩放问题而不准确。

  7. CUDA 内核加速:函数的核心计算由 CUDA 内核 FlashAttnFunc.apply 完成,提供高性能的并行计算能力。

总的来说,flash_attn_func 函数集成了多种先进的注意力计算技术,可以高效地计算注意力层的输出,支持各种用途和优化方式。

def flash_attn_func(
    q, # 查询张量,形状为 (batch_size, seqlen, nheads, headdim)
    k, # key 张量,形状为 (batch_size, seqlen, nheads_k, headdim)
    v, # value 张量,形状为 (batch_size, seqlen, nheads_k, headdim)
    dropout_p=0.0, # dropout 概率,评估时应设置为 0.0
    softmax_scale=None, # softmax 缩放系数,默认为 1 / sqrt(headdim)
    causal=False, # 是否应用因果注意力掩码,例如用于自回归建模
    window_size=(-1, -1), # 滑动窗口大小,(-1, -1) 表示无限上下文窗口
    alibi_slopes=None, # 注意力分数偏置,形状为 (nheads,) 或 (batch_size, nheads)
    deterministic=False, # 是否使用确定性反向传播实现,稍慢但使用更多内存,正向传播始终是确定性的
    return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能由于缩放问题而不准确
):
    """
    该函数用于计算注意力层的输出。

    支持多查询和分组查询注意力(MQA/GQA),方法是将 KV 的头数量设置为少于 Q 的头数量。
    注意 Q 中的头数量必须能被 KV 中的头数量整除。
    例如,如果 Q 有 6 个头,而 K、V 有 2 个头,那么 Q 的第 0、1、2 个头将关注 K、V 的第 0 个头,而 Q 的第 3、4、5 个头将关注 K、V 的第 1 个头。

    如果 causal=True,则因果掩码对齐到注意力矩阵的右下角。
    例如,如果 seqlen_q = 2 且 seqlen_k = 5,则因果掩码(1 = 保留,0 = 掩码)为:
        1 1 1 1 0
        1 1 1 1 1
    如果 seqlen_q = 5 且 seqlen_k = 2,则因果掩码为:
        0 0
        0 0
        0 0
        1 0
        1 1
    如果掩码行全为 0,则输出也将为 0。

    如果 window_size != (-1, -1),则实现滑动窗口局部注意力。
    位置 i 的查询将只关注位置在 [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] 范围内(包括边界)的 key。

    参数:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        dropout_p: float。dropout 概率。
        softmax_scale: float。QK^T 在应用 softmax 之前的缩放系数。默认为 1 / sqrt(headdim)。
        causal: bool。是否应用因果注意力掩码(例如用于自回归建模)。
        window_size: (left, right)。如果不是 (-1, -1),则实现滑动窗口局部注意力。
        alibi_slopes: (nheads,) 或 (batch_size, nheads), fp32。将 (-alibi_slope * |i + seqlen_k - seqlen_q - j|) 的偏置加到查询 i 和 key j 的注意力分数上。
        deterministic: bool。是否使用确定性反向传播实现,稍慢但使用更多内存。正向传播始终是确定性的。
        return_attn_probs: bool。是否返回注意力概率。这个选项仅用于测试。返回的概率可能由于缩放问题而不准确。

    返回:
        out: (batch_size, seqlen, nheads, headdim)。
        softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen)。每行 QK^T * 缩放系数的 logsumexp (例如,softmax 归一化因子的对数)。
        S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen)。softmax 的输出(可能有不同的缩放)。它还编码了 dropout 模式(负值表示该位置被丢弃,非负值表示被保留)。
    """
    # 调用 FlashAttnFunc 类的 apply 方法计算注意力输出
    return FlashAttnFunc.apply(
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        alibi_slopes,
        deterministic,
        return_attn_probs,
    )

  • 10
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值