FlexAttention 详解:PyTorch 代码实现多种 Attention 变体

背景

  • 类似 FlashAttention 这样优化过的 Attention 实现虽然极大提升了性能,但是这种效率的提升却牺牲了灵活性。无法再通过编写一些 PyTorch 操作符来尝试新的 Attention 变体,通常需要编写一个新的自定义内核。

    • 一些 attention 变体包含:causal、相对位置嵌入、Alibi、滑动窗口 attention、PrefixLM、Document Masking/Sample Packing/Jagged Tensors, Tanh Soft-Capping, PagedAttention。更麻烦的时候有时候人们还希望组合上面这些变体。
  • 下图左侧是现在的状态,一些掩码 + 偏置 + 设置的组合已经有了现成的内核实现。但各种选项导致了设置数量的指数增长,因此整体支持非常不均衡。更糟糕的是,研究人员提出的新 Attention 变体将根本得不到支持。为了彻底解决这个超立方体问题,引入了 FlexAttention。

    • FlexAttention 提供了一个灵活的 API,允许使用几行惯用的 PyTorch 代码实现多种 Attention 变体(包括博客文章中提到的所有变体)。
    • 通过 torch.compile 将其降低为一个融合的 FlashAttention 内核,生成一个不会产生额外内存且性能与手写内核相当的 FlashAttention 内核。
    • 自动生成反向传播,利用 PyTorch 的 autograd 机制。
    • 可以利用 Attention 掩码中的稀疏性,相比标准的 Attention 实现显著提高了性能。

在这里插入图片描述

flexattention 相关文档

  • https://pytorch.org/docs/main/nn.attention.flex_attention.html#module-torch.nn.attention.flex_attention
  • https://pytorch.org/blog/flexattention/
  • https://github.com/pytorch-labs/attention-gym

FlexAttention 介绍

FlexAttention 的实现形式

这是经典的 Attention 方程:
在这里插入图片描述
代码形式:

Q, K, V: Tensor[batch_size, num_heads, sequence_length, head_dim]
score: Tensor[batch_size, num_heads, sequence_length, sequence_length] = (Q @ K) / sqrt(head_dim)
probabilities = softmax(score, dim=-1)
output: Tensor[batch_size, num_heads, sequence_length, head_dim] = probabilities @ V

FlexAttention 允许用户定义一个函数 score_mod
在这里插入图片描述
代码形式:

Q, K, V: Tensor[batch_size, num_heads, sequence_length, head_dim]
score: Tensor[batch_size, num_heads, sequence_length, sequence_length] = (Q @ K) / sqrt(head_dim)
modified_scores: Tensor[batch_size, num_heads, sequence_length, sequence_length] = score_mod(score)
probabilities = softmax(modified_scores, dim=-1)
output: Tensor[batch_size, num_heads, sequence_length, head_dim] = probabilities @ V

这个函数允许你在 softmax 之前修改 Attention 分数。这对于绝大多数 Attention 变体来说已经足够了。

具体来说,输入的 score 是一个表示查询 token 和键 token 点积的 PyTorch 标量。其余参数告诉你当前正在计算哪个点积——b(批量中的当前元素),h(当前头),q_idx(查询中的位置),kv_idx(键/值张量中的位置)。

def score_mod(score: f32[], b: i32[], h: i32[], q_idx: i32[], kv_idx: i32[]):
    return score  # 无操作 - 标准 Attention

为了应用这个函数,我们可以将其实现为:

for b in range(batch_size):
    for h in range(num_heads):
        for q_idx in range(sequence_length):
            for kv_idx in range(sequence_length):
                modified_scores[b, h, q_idx, kv_idx] = score_mod(scores[b, h, q_idx, kv_idx], b, h, q_idx, kv_idx)

这并不是 FlexAttention 底层的实现方式。通过利用 torch.compile,torch 会自动将你的函数降低为单个融合的 FlexAttention 内核。

SCORE MOD 式样样例

完全注意力 (Full Attention)

完全注意力是标准的双向注意力。在这种情况下,score_mod 是一个无操作的函数——它接收输入的分数,然后按原样返回它们。

def noop(score, b, h, q_idx, kv_idx):
    return score

要将其端到端使用(包括前向和后向):

from torch.nn.attention.flex_attention import flex_attention
flex_attention(query, key, value, score_mod=noop).sum().backward()
相对位置编码 (Relative Position Encodings)

一个常见的 Attention 变体是“相对位置编码”,它不是对查询和键进行绝对距离编码,而是根据查询和键之间的“距离”调整分数。

def relative_positional(score, b, h, q_idx, kv_idx):
    return score + (q_idx - kv_idx)
ALiBi 偏置

在这里插入图片描述
ALiBi 在《Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation》一文中被提出,并声称在推理时具有有益的长度外推属性。
Alibi 类似于相对位置编码,唯一的例外是它具有一个通常预先计算的 per-head factor。

alibi_bias = generate_alibi_bias()  # [num_heads]

def alibi(score, b, h, q_idx, kv_idx):
    bias = alibi_bias[h] * (q_idx - kv_idx)
    return score + bias
Soft-capping

Soft-capping 是一种在 Gemma2 和 Grok-1 中使用的技术,可以防止 logits 过度增大。

softcap = 20
def soft_cap(score, b, h, q_idx, kv_idx):
    score = score / softcap
    score = torch.tanh(score)
    score = score * softcap
    return score
因果掩码
  • 《Attention is All You Need》论文以及绝大多数 LLMs 使用的是一种仅解码的设置,其中每个标记只能关注之前的标记。人们通常将其视为下三角掩码。
def causal_mask(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

基本上,如果查询标记在键标记“之后”,我们保留分数。否则,我们通过将其设置为 -inf 来屏蔽它,从而确保它不会参与 softmax 计算。

因果掩码优化版
  • 如果某些东西被屏蔽了,我们可以完全跳过它的计算!在这种情况下,因果掩码具有大约 50% 的稀疏性,因此不利用稀疏性会导致 2 倍的速度减慢。
  • 通过 mask_mod 可以实现带 mask 的稀疏性
# returns True if this position should participate in the computation
mask_mod(b, h, q_idx, kv_idx) => bool
  • 通过 create_block_mask 来实现高效版本 causal attention
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# 由于稀疏模式与批次和头无关,我们将它们设置为 None(这会广播它们)
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=1024, KV_LEN=1024)
# 在这种情况下,我们不需要 score_mod,因此不会传入它。不过,如果你需要额外的灵活性,score_mod 仍然可以与 block_mask 结合使用。
flex_attention(query, key, value, block_mask=block_mask)

请注意,create_block_mask 是一个相对耗时的操作!尽管 FlexAttention 不会因为掩码变化而重新编译,但如果你不小心缓存它,可能会导致显著的速度减慢
在这里插入图片描述
尽管 TFlops 大致相同,但 mask_mod 版本的执行时间快了两倍!这表明我们可以在不损失硬件效率的情况下,利用 BlockMask 提供的稀疏性。

滑动窗口 + 因果掩码
  • 由 Mistral 推广的滑动窗口注意力(也称为局部注意力)利用了最近的 token 最有用的直觉。特别地,它允许查询 token 仅关注最近的 1024 个 token。这通常与因果注意力一起使用。
    在这里插入图片描述
  • 代码实现
SLIDING_WINDOW = 1024

def sliding_window_causal(b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    window_mask = q_idx - kv_idx <= SLIDING_WINDOW 
    return causal_mask & window_mask

# 以下是更简洁的实现方式...
from torch.nn.attention import or_masks

def sliding_window(b, h, q_idx, kv_idx)
    return q_idx - kv_idx <= SLIDING_WINDOW

sliding_window_causal = or_masks(causal_mask, sliding_window)
  • 测试结果
    • 我们将其与使用滑动窗口掩码的 F.scaled_dot_product_attention 以及使用因果掩码的 FA2 进行了性能基准测试(作为性能的参考点)。我们不仅比 F.scaled_dot_product_attention 显著更快,而且也显著快于带因果掩码的 FA2,因为这种掩码具有显著更多的稀疏性。
      在这里插入图片描述
PrefixLM
  • T5 架构,提出于《Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer》一文中,描述了一种注意力变体:在“前缀”上执行全双向注意力,在其余部分执行因果注意力。我们同样通过组合两个掩码函数来实现这一点,一个用于因果掩码,另一个基于前缀长度。
    在这里插入图片描述
  • 代码实现
prefix_length: [B]
def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx <= prefix_length[b]

prefix_lm_causal = or_masks(prefix_mask, causal_mask)
# 在这种情况下,我们的掩码会根据每个序列的不同而有所不同,因此我们将 B 设置为批量大小
block_mask = create_block_mask(prefix_lm_causal, B=B, H=None, S, S)

就像 score_mod 一样,mask_mod 允许我们引用函数输入中未显式包含的其他张量!然而,使用 PrefixLM 时,稀疏模式会随着输入的变化而变化。这意味着对于每个新的输入批次,我们需要重新计算 BlockMask。常见的做法是,在模型开始时调用 create_block_mask 并在模型的所有注意力调用中重复使用该 block_mask。

文档掩码/不规则序列
  • 另一种常见的注意力变体是文档掩码/不规则序列。假设你有许多长度不同的序列。你想要一起训练它们,但不幸的是,大多数运算符只接受矩形张量。
    在这里插入图片描述

  • 通过 BlockMask,我们也可以在 FlexAttention 中高效地支持这一点!

    • 首先,我们将所有序列展平成一个包含 sum(sequence lengths) 个 token 的序列。
    • 然后,我们计算每个 token 所属的文档 ID。
    • 最后,在我们的 mask_mod 中,我们只需判断查询 token 和键值 token 是否属于同一个文档!
  • 代码实现

# 每个标记所属的文档。
# 例如,[0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2] 对应序列长度 3、2 和 6。
document_id: [SEQ_LEN]

def document_masking(b, h, q_idx, kv_idx):
    return document_id[q_idx] == document_id[kv_idx]
块对角掩码

关于文档掩码的一个有趣方面是,容易看出它如何与其他任意组合的掩码结合。例如,我们在上一节中已经定义了 prefixlm_mask。现在我们是否需要定义一个 prefixlm_document_mask 函数呢?

在这些情况下,我们发现一种非常有用的模式是所谓的“更高级别的修改”。在这种情况下,我们可以将现有的 mask_mod 自动转换为适用于不规则序列的版本!

def generate_doc_mask_mod(mask_mod, document_id):
    # 获取唯一文档 ID 及其计数
    _, counts = torch.unique_consecutive(document_id, return_counts=True)
    # 创建累计计数(偏移量)
    offsets = torch.cat([torch.tensor([0], device=document_id.device), counts.cumsum(0)[:-1]])
    def doc_mask_wrapper(b, h, q_idx, kv_idx):
        same_doc = document_id[q_idx] == document_id[kv_idx]
        q_logical = q_idx - offsets[document_id[q_idx]]
        kv_logical = kv_idx - offsets[document_id[kv_idx]]
        inner_mask = mask_mod(b, h, q_logical, kv_logical)
        return same_doc & inner_mask
    return doc_mask_wrapper

例如,使用上面提到的 prefix_lm_causal 掩码,我们可以将其转换为适用于打包文档的掩码,如下所示:

prefix_length = torch.tensor(2, dtype=torch.int32, device="cuda")
def prefix_mask(b, h, q_idx, kv_idx):
    return kv_idx < prefix_length
prefix_lm_causal = or_masks(prefix_mask, causal_mask)
doc_prefix_lm_causal_mask = generate_doc_mask_mod(prefix_lm_causal, document_id)

在这里插入图片描述
现在,这个掩码是“块-前缀LM-对角”形状的。

常见问题解答

问:FlexAttention 何时需要重新编译?

由于 FlexAttention 利用 torch.compile 来捕获图形,它实际上在很广泛的情况下可以避免重新编译。特别是,即使捕获的张量的值发生变化,也不需要重新编译!

flex_attention = torch.compile(flex_attention)

def create_bias_mod(bias):
    def bias_mod(score, b, h, q_idx, kv_idx):
        return score + bias
    return bias_mod

bias_mod1 = create_bias_mod(torch.tensor(0))
flex_attention(..., score_mod=bias_mod1)  # 在这里编译内核

bias_mod2 = create_bias_mod(torch.tensor(2))
flex_attention(..., score_mod=bias_mod2)  # 无需重新编译!

即使改变了块稀疏性(block-sparsity),也不需要重新编译。但是,如果块稀疏性发生变化,我们确实需要重新计算 BlockMask。

问:何时需要重新计算 BlockMask?

每当块稀疏性发生变化时,我们需要重新计算 BlockMask。尽管计算 BlockMask 比重新编译便宜得多(通常是几百微秒,而不是几秒),但您仍然应该注意不要过度重新计算 BlockMask。

以下是一些常见的模式和一些建议。

掩码从不变化(例如因果掩码)
在这种情况下,您可以简单地预计算块掩码并全局缓存,重复使用它进行所有注意力调用。

block_mask = create_block_mask(causal_mask, 1, 1, S, S)
causal_attention = functools.partial(flex_attention, block_mask=block_mask)

掩码每批次变化(例如文档掩码)
在这种情况下,我们建议在模型开始时计算 BlockMask,并在模型中传递它 - 对所有层重复使用 BlockMask。

def forward(self, x, doc_mask):
    # 在前向传递的开始计算块掩码
    block_mask = create_block_mask(doc_mask, None, None, S, S)
    x = self.layer1(x, block_mask)
    x = self.layer2(x, block_mask)
    ...
    # 将块掩码构造成本摊到所有层
    x = self.layer3(x, block_mask)
    return x

掩码每层变化(例如数据依赖的稀疏性)
这是最难的情况,因为我们无法将块掩码计算的成本摊到多个 FlexAttention 调用上。尽管 FlexAttention 在这种情况下仍然有优势,但实际的 BlockMask 效益取决于您的注意力掩码有多稀疏以及我们能多快地构建 BlockMask。这就引出了一个问题…

问:我们如何更快地计算 BlockMask?

create_block_mask 在内存和计算方面都相当昂贵,因为要确定一个块是否完全稀疏,必须在块中的每个点评估 mask_mod。有几种方法可以解决这个问题:

  • 如果您的掩码在批次大小或头之间是相同的,请确保您对这些维度进行广播(即在 create_block_mask 中将它们设置为 None)。
  • 编译 create_block_mask。不幸的是,目前 torch.compile 不能直接在 create_block_mask 上工作,因为存在一些限制。但是,您可以设置 _compile=True,这将显著减少峰值内存和运行时间(在我们的测试中通常减少一个数量级)。
  • 编写自定义的 BlockMask 构造器。BlockMask 的元数据非常简单(请参阅文档)。它本质上是两个张量。a. num_blocks:为每个查询块计算的 KV 块数。b. indices:为每个查询块计算的 KV 块的位置。

例如,这里有一个针对因果掩码的自定义 BlockMask 构造器。

def create_causal_mask(S):
    BLOCK_SIZE = 128
    # 第一个查询块计算一个块,第二个查询块计算两个块,依此类推。
    num_blocks = torch.arange(S // BLOCK_SIZE, device="cuda") + 1
    # 由于我们总是从左向右计算,
    # 我们可以对每个查询块使用 [0, 1, 2, ...] 的索引。
    indices = torch.arange(S // BLOCK_SIZE, device="cuda").expand(
        S // BLOCK_SIZE, S // BLOCK_SIZE
    )
    num_blocks = num_blocks[None, None, :]
    indices = indices[None, None, :]
    return BlockMask(num_blocks, indices, BLOCK_SIZE=BLOCK_SIZE, mask_mod=causal_mask)
问:为什么 score_modmask_mod 不同?mask_mod 不就是 score_mod 的一种特殊情况吗?

这是一个非常敏锐的问题!实际上,任何 mask_mod 都可以轻松转换为 score_mod(我们不建议在实际中使用这个函数!)

def mask_mod_as_score_mod(b, h, q_idx, kv_idx):
    return torch.where(mask_mod(b, h, q_idx, kv_idx), score, -float("inf"))

那么,如果 score_mod 能实现 mask_mod 的所有功能,为什么还要保留 mask_mod 呢?

一个直接的挑战是:score_mod 需要实际的分数值作为输入,但在预计算 BlockMask 时,我们并没有实际的分数值。我们可以通过传递全零值来伪造这些值,如果 score_mod 返回 -inf,那么我们认为它是被掩码了(实际上,我们最初是这么做的!)。

然而,这有两个问题。首先,这种方法是很投机的——如果用户的 score_mod 在输入为 0 时返回 -inf,怎么办?或者,如果用户的 score_mod 用一个大的负值而不是 -inf 进行掩码,怎么办?看起来我们是在用方形钉子填补圆形孔。但是,分离 mask_modscore_mod 还有一个更重要的原因——它在本质上更有效率!

事实证明,对每一个计算出的元素应用掩码实际上是相当昂贵的——我们的基准测试显示性能下降约 15-20%!所以,尽管通过跳过一半的计算我们可以获得显著的速度提升,但由于需要对每个元素进行掩码,我们会损失一部分速度提升!

幸运的是,如果我们可视化因果掩码,我们会发现绝大多数块根本不需要“因果掩码”——它们是完全计算的!只有在对角线上的块是部分计算和部分掩码的,才需要应用掩码。
在这里插入图片描述

块对角掩码

BlockMask 之前告诉我们哪些块需要计算,哪些块可以跳过。现在,我们进一步增强了这个数据结构,告诉我们哪些块是“完全计算”的(即可以跳过掩码)和哪些块是“部分计算”的(即需要应用掩码)。然而,需要注意的是,尽管在“完全计算”的块上可以跳过掩码,但像相对位置嵌入这样的 score_mod 仍然需要应用。

仅仅通过 score_mod,我们无法可靠地判断它的哪些部分是“掩码”。因此,用户必须自己将这些部分分离到 mask_mod 中。

问:BlockMask 需要多少额外的内存?

BlockMask 的元数据大小为 [BATCH_SIZE, NUM_HEADS, QUERY_LEN//BLOCK_SIZE, KV_LEN//BLOCK_SIZE]。如果掩码在批次或头部维度上是相同的,可以在该维度上进行广播以节省内存。

在默认的 BLOCK_SIZE 为 128 的情况下,我们预计内存使用在大多数情况下是微不足道的。例如,对于 100 万的序列长度,BlockMask 只会使用 60MB 的额外内存。如果这是一个问题,您可以增加块大小:create_block_mask(..., BLOCK_SIZE=1024)。例如,将 BLOCK_SIZE 增加到 1024 将使元数据减少到不到 1MB。

问:数值比较如何?

尽管结果不是位相同的,但我们有信心 FlexAttention 的数值精度与 FlashAttention 一样高。我们在广泛的输入范围内比较了 FlashAttention 与 FlexAttention 的差异,涵盖了因果和非因果注意力变体。误差几乎是相同的。

分布图
在这里插入图片描述

性能

总的来说,FlexAttention 的性能几乎与手写的 Triton 内核一样高效,这不足为奇,因为我们在很大程度上利用了手写的 Triton 内核。然而,由于其通用性,我们确实付出了一些性能代价。例如,我们必须支付一些额外的延迟来确定下一个要计算的块。在某些情况下,我们提供了一些内核选项,可以在改变内核行为的同时影响其性能。

  • FlexAttention 实现了 FlashAttention2 性能的 90%,而在反向传播中达到了 85%。目前,FlexAttention 使用了一种确定性算法,该算法比 FAv2 重新计算了更多的中间变量,但我们计划改进 FlexAttention 的反向算法,并希望缩小这一差距!
    在这里插入图片描述
    在这里插入图片描述

局限性和未来工作

  • FlexAttention 目前在 PyTorch 的夜间版本中可用,我们计划在 2.5.0 中将其作为原型功能发布。
  • 我们在此没有涉及如何将 FlexAttention 用于推理(或如何实现 PagedAttention)——我们将在后续文章中介绍这些内容。
  • 我们正在努力提高 FlexAttention 的性能,以在 H100 GPU 上匹配 FlashAttention3 的表现。
  • FlexAttention 要求所有序列长度必须是 128 的倍数——我们将很快解决这个问题。
  • 我们计划很快添加 GQA 支持——目前,你可以通过复制 kv 头来实现。

总结

  • FlexAttention 提供了比 FlashAttention 更加灵活的接口,同时保持了几乎相同的性能。虽然在某些特定情况下 FlexAttention 可能略逊于 FlashAttention,但在大多数情况下,两者之间的差距是微不足道的。
  • 12
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值