目录
2.1、flash_attn_varlen_qkvpacked_func
2.2、flash_attn_varlen_kvpacked_func
2.4、flash_attn_with_kvcache
一、flash attention官方
版本: flash-attn 2.5.7
1.1、flash attention安装(下文1.2节)
二、flash attention 常见函数
2.1、flash_attn_varlen_qkvpacked_func
输入前、输出后需要使用unpad、pad
该函数
flash_attn_varlen_qkvpacked_func
用于计算可变长序列的注意力输出,其中 query、key 和 value 已被打包成一个张量。该函数的主要作用是:
高效计算注意力输出:通过将 query、key 和 value 打包成一个张量作为输入,避免了显式连接 Q、K、V 的梯度,从而提高了计算效率。
支持变长序列:函数通过
cu_seqlens
参数接收每个序列的累积长度,可以有效处理变长序列的情况。支持多种注意力模式:函数支持因果注意力掩码(用于自回归建模)、滑动窗口局部注意力(只关注特定范围内的 key)和添加注意力分数偏置等功能。
提供确定性反向传播选项:可以选择使用确定性反向传播实现,虽然稍慢但使用更多内存,保证了结果的确定性。
返回注意力概率(仅用于测试):可以选择返回注意力概率,但这些概率可能不具有正确的缩放,仅用于测试目的。
该函数的输入参数包括 query、key 和 value 张量、序列长度信息、Dropout 率、softmax 缩放因子、注意力模式选项等。输出则是注意力层的输出张量,以及可选的注意力概率和 softmax 归一化因子。
该函数利用了 PyTorch 的自定义 CUDA 扩展,提供了高效的注意力计算能力,同时支持了多种注意力模式和变长序列输入。
def flash_attn_varlen_qkvpacked_func(
qkv, # query、key 和 value 张量,形状为 (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起。例如,如果批次大小为 2,序列长度分别为 3 和 5,头数为 4,head 维度为 64,则 qkv 的形状为 (8, 3, 4, 64)
cu_seqlens, # 序列的累积长度,形状为 (batch_size + 1),数据类型为 torch.int32,用于索引 qkv。例如,如果批次大小为 2,序列长度分别为 3 和 5,则 cu_seqlens 为 [0, 3, 8]
max_seqlen, # 批次中序列的最大长度,整数值。例如,如果批次中的序列长度分别为 3 和 5,则 max_seqlen 为 5
dropout_p=0.0, # dropout 概率,在评估(evaluation)时应设置为 0.0,以保留所有神经元的输出,防止信息损失。通常在训练时使用一个小于 1 的值,如 0.1,而在评估时设置为 0.0
softmax_scale=None, # 在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)。如果 headdim 为 64,则默认缩放因子为 1 / sqrt(64) ≈ 0.126
causal=False, # 是否应用因果注意力掩码(用于自回归建模)。如果设置为 True,则查询只能关注之前的输出,无法关注未来的输出,这在语言模型等自回归任务中很有用
window_size=(-1, -1), # 用于实现滑动窗口局部注意力,(-1, -1) 表示无限上下文窗口,即不进行任何窗口限制。如果设置为 (2, 2),则查询位置 i 只能关注位置 [i-2, i+2] 范围内的键
alibi_slopes=None, # 一种用于注意力分数偏置的方法,形状为 (nheads,) 或 (batch_size, nheads),数据类型为 fp32。例如,如果 nheads 为 4,则 alibi_slopes 可以是形状为 (4,) 的张量
deterministic=False, # 是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的。通常在评估时不需要使用确定性实现,因为评估时不需要计算梯度
return_attn_probs=False, # 是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确,因为在实际应用中,通常不需要直接访问注意力概率,而是更关注注意力层的输出
):
"""
dropout_p 在评估(evaluation)时应该设置为 0.0,以保留所有神经元的输出,防止信息损失。
如果 query、key 和 value 已经打包成一个张量,调用这个函数会比调用 flash_attn_varlen_func 更快,因为反向传播避免了显式拼接 query、key 和 value 的梯度,从而减少了内存复制和计算量。
对于多查询注意力(MQA)和分组查询注意力(GQA),请参见 flash_attn_varlen_kvpacked_func 和 flash_attn_varlen_func。
如果 window_size != (-1, -1),则实现滑动窗口局部注意力。位置 i 处的查询只会注意到位置在 [i - window_size[0], i + window_size[1]] 范围内(包括边界)的键。
参数说明:
qkv: (total, 3, nheads, headdim),其中 total 是批次中 token 的总数,3 表示 query、key 和 value 被打包在一起
cu_seqlens: (batch_size + 1,),数据类型为 torch.int32,表示批次中每个序列的累积长度,用于索引 qkv
max_seqlen: 整数,批次中序列的最大长度
dropout_p: 浮点数,dropout 概率
softmax_scale: 浮点数,在应用 softmax 之前对 QK^T 进行缩放的系数,默认为 1 / sqrt(headdim)
causal: 布尔值,是否应用因果注意力掩码(用于自回归建模)
window_size: (left, right),整数元组,如果不是 (-1, -1),则实现滑动窗口局部注意力
alibi_slopes: (nheads,) 或 (batch_size, nheads),fp32 张量,一种用于注意力分数偏置的方法
deterministic: 布尔值,是否使用确定性的反向传播实现,这种实现稍慢且内存占用更高,但是前向传播始终是确定性的
return_attn_probs: 布尔值,是否返回注意力概率,仅用于测试,返回的概率可能缩放不正确
返回值:
out: (total, nheads, headdim),注意力层的输出
softmax_lse [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen),每行的 QK^T * scaling 的 logsumexp(即 softmax 归一化因子的对数)
S_dmask [可选,如果 return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen),softmax 的输出(可能有不同的缩放),它也编码了 dropout 模式(负值表示该位置被 dropout,非负值表示被保留)
"""
return FlashAttnVarlenQKVPackedFunc.apply(
qkv, # query、key 和 value 张量
cu_seqlens, # 序列的累积长度
max_seqlen, # 批次中序列的最大长度
dropout_p, # dropout 概率
softmax_scale, # softmax 缩放因子
causal, # 是否应用因果注意力掩码
window_size, # 滑动窗口大小,用于实现局部注意力
alibi_slopes, # 注意力分数偏置方法
deterministic, # 是否使用确定性反向传播实现
return_attn_probs, # 是否返回注意力概率(