flash attention 参数(笔记)

本文介绍了Flash Attention的官方版本及安装方法,需确保Linux外界与conda虚拟环境中cuda版本一致,安装好c++、g++、ninja。还详细阐述了其常见函数,如flash_attn_varlen_qkvpacked_func等,这些函数能高效计算注意力输出,支持变长序列、多种注意力模式等,在自然语言处理等领域应用广泛。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

目录

一、flash attention官方

  1.1、flash attention安装

二、flash attention 常见函数

  2.1、flash_attn_varlen_qkvpacked_func

  2.2、flash_attn_varlen_kvpacked_func

  2.3、flash_attn_varlen_func

  ​​​​​​​2.4、flash_attn_with_kvcache

  2.5、flash_attn_func


一、flash attention官方

版本: flash-attn  2.5.7

flash-attention/flash_attn/flash_attn_interface.py at main · Dao-AILab/flash-attention · GitHubFast and memory-efficient exact attention. Contribute to Dao-AILab/flash-attention development by creating an account on GitHub.icon-default.png?t=N7T8https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py

  1.1、flash attention安装(下文1.2节)

Trl SFT: llama2-7b-hf使用QLora 4bit量化后ds zero3加上flash atten v2单机多卡训练(笔记)_unsloth 多卡-CSDN博客文章浏览阅读812次,点赞18次,收藏25次。第三 参考官方命令: https://github.com/Dao-AILab/flash-attention。第一 确保 linux "外界"的 cuda版本 与 conda 虚拟环境中cuda版本一致。第二 安装好 c++ g++ ninja。_unsloth 多卡https://blog.csdn.net/qq_16555103/article/details/137677561

二、flash attention 常见函数

  2.1、flash_attn_varlen_qkvpacked_func

输入前、输出后需要使用unpad、pad

该函数 flash_attn_varlen_qkvpacked_func 用于计算可变长序列的注意力输出,其中 query、key 和 value 已被打包成一个张量。

该函数的主要作用是:

  1. 高效计算注意力输出:通过将 query、key 和 value 打包成一个张量作为输入,避免了显式连接 Q、K、V 的梯度,从而提高了计算效率。

  2. 支持变长序列:函数通过 cu_seqlens 参数接收每个序列的累积长度,可以有效处理变长序列的情况。

  3. 支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。

  4. 提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。

  5. 返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。

该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。

该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式和变长序列输入。

def flash_attn_varlen_qkvpacked_func(
    qkv,  # query、key 和 value 张量,形状为 (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起。例如,如果批次大小为 2,序列长度分别为 3 和 5,头数为 4,head 维度为 64,则 qkv 的形状为 (8, 3, 4, 64)
    cu_seqlens,  # 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,用于索引 qkv。例如,如果批次大小为 2,序列长度分别为 3 和 5,则 cu_seqlens 为 [0, 3, 8]
    max_seqlen,  # 批次中序列的最大长度,整数值。例如,如果批次中的序列长度分别为 3 和 5,则 max_seqlen 为 5
    dropout_p=0.0,  # dropout 概率,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。通常在训练时使用一个小于 1 的值,如 0.1,而在评估时设置为 0.0
    softmax_scale=None,  # 在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)。如果 headdim 为 64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126
    causal=False,  # 是否应用因果注意力掩码(用于自回归建模)。如果设置为 True,则查询只能关注之前的输出,无法关注未来的输出,这在语言模型等自回归任务中很有用
    window_size=(-1, -1),  # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 (2, 2),则查询位置 i 只能关注位置 [i-2, i+2] 范围内的键
    alibi_slopes=None,  # 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads 为 4,则 alibi_slopes 可以是形状为 (4,) 的张量
    deterministic=False,  # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。通常在评估时不需要使用确定性实现,因为评估时不需要计算梯度
    return_attn_probs=False,  # 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
    """
    dropout_p 在评估(evaluation)时应该设置为 0.0,以保留所有神经元的输出,防止信息损失。

    如果 query、key 和 value 已经打包成一个张量,调用这个函数会比调用 flash_attn_varlen_func 更快,因为反向传播避免了显式拼接 query、key 和 value 的梯度,从而减少了内存复制和计算量。

    对于多查询注意力(MQA)和分组查询注意力(GQA),请参见 flash_attn_varlen_kvpacked_func 和 flash_attn_varlen_func。

    如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 处的查询只会注意到位置在 [i - window_size[0], i + window_size[1]] 范围内(包括边界)的键。

    参数说明:
    qkv: (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起
    cu_seqlens: (batch_size + 1,),数据类型为 torch.int32,表示批次中每个序列的累积长度,用于索引 qkv
    max_seqlen: 整数,批次中序列的最大长度
    dropout_p: 浮点数,dropout 概率
    softmax_scale: 浮点数,在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)
    causal: 布尔值,是否应用因果注意力掩码(用于自回归建模)
    window_size: (left, right),整数元组,如果不是 (-1, -1),则实现滑动窗口局部注意力
    alibi_slopes: (nheads,) 或 (batch_size, nheads),fp32 张量,一种用于注意力分数偏置的方法
    deterministic: 布尔值,是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的
    return_attn_probs: 布尔值,是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确

    返回值:
    out: (total, nheads, headdim),注意力层的输出
    softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen),每行的 QK^T * scaling 的 logsumexp(即 softmax 归一化因子的对数)
    S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),softmax 的输出(可能有不同的缩放),它也编码了 dropout 模式(负值表示该位置被 dropout,非负值表示被保留)
    """

    return FlashAttnVarlenQKVPackedFunc.apply(
        qkv,  # query、key 和 value 张量
        cu_seqlens,  # 序列的累积长度
        max_seqlen,  # 批次中序列的最大长度    
        dropout_p,  # dropout 概率
        softmax_scale,  # softmax 缩放因子
        causal,  # 是否应用因果注意力掩码
        window_size,  # 滑动窗口大小,用于实现局部注意力
        alibi_slopes,  # 注意力分数偏置方法
        deterministic,  # 是否使用确定性反向传播实现
        return_attn_probs,  # 是否返回注意力概率(
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值