Gemma 2 的滑动窗口注意力(Sliding Window Attention)解析:源代码

Gemma 2 的滑动窗口注意力(Sliding Window Attention)解析


Transformer 结构 中,自注意力(Self-Attention)是核心机制之一。然而,标准的自注意力计算复杂度为 ( O ( n 2 ) O(n^2) O(n2) ),随着序列长度增加,计算和内存开销都会急剧增长。为了解决这一问题,Gemma 2 采用了滑动窗口注意力(Sliding Window Attention)机制,通过限制每个 token 只能关注附近的 tokens,大幅降低计算复杂度,使模型能够高效处理更长的文本序列。

本文将详细解析 Gemma 2 的 Sliding Window Attention 的实现原理、代码细节以及它带来的优化效果。


1. 为什么需要滑动窗口注意力?

1.1 标准 Transformer 的计算瓶颈

在标准 Transformer 中,自注意力计算如下:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax} \left(\frac{QK^T}{\sqrt{d_k}} \right) V Attention(Q,K,V)=softmax(dk QKT)V
其中:

  • ( Q , K , V Q, K, V Q,K,V ) 是 查询(Query)、键(Key)、值(Value) 矩阵,形状为 ( ( n , d k ) (n, d_k) (n,dk) )。
  • 计算复杂度为 ( O ( n 2 d k ) O(n^2 d_k) O(n2dk) ),其中 ( n n n ) 是序列长度。

问题:

  • 对于 长文本(如 8K、16K token),标准 Transformer 计算量和显存占用急剧上升
  • 由于 每个 token 计算所有 token 的注意力分数,使得 全自注意力不适用于长文本任务

1.2 Sliding Window Attention 如何优化?

滑动窗口注意力(Sliding Window Attention)核心思想:

  • 每个 token 仅关注窗口范围内的 tokens,而不是整个序列。
  • 计算复杂度降低为 ( O ( n ⋅ w ) O(n \cdot w) O(nw) ),其中 ( w w w ) 是窗口大小。
  • 适用于 长文本任务,如文档摘要、问答、代码补全等

📌 示例

  • 如果窗口大小 ( w = 512 ):
    • 第 1000 个 token 仅关注 [488, ..., 1000, ..., 1512] 范围内的 tokens。
    • 而不是像标准 Transformer 一样计算 [0, ..., 1000, ..., n] 之间的所有 token 交互。

🔹 这样,每个 token 只计算窗口内 tokens 的注意力,大幅减少计算量!

可以参考笔者的另一篇博客:Sliding Window Attention(滑动窗口注意力)解析: Pytorch实现并结合全局注意力(Global Attention )


2. Gemma 2 中的滑动窗口注意力实现

Gemma 2 的 GemmaAttention 代码中,滑动窗口注意力的实现如下:

if (
    self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING
    and self.sliding_window_size is not None
):
    all_ones = torch.ones_like(mask)
    sliding_mask = torch.triu(
        all_ones, -1 * self.sliding_window_size + 1
    ) * torch.tril(all_ones, self.sliding_window_size - 1)
    mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)

3. 代码解析

这段代码的作用是 创建滑动窗口掩码(Sliding Mask),确保每个 token 只能关注窗口范围内的 tokens

(1) 创建全 1 矩阵

all_ones = torch.ones_like(mask)
  • 生成一个与 mask 形状相同的全 1 矩阵 all_ones
  • 这个矩阵用于构建滑动窗口的掩码。

示例
假设 mask 形状为 [batch, num_heads, seq_len, seq_len],即:

all_ones = torch.tensor([
    [[1, 1, 1, 1, 1],  
     [1, 1, 1, 1, 1],  
     [1, 1, 1, 1, 1],  
     [1, 1, 1, 1, 1],  
     [1, 1, 1, 1, 1]]
])

(2) 生成滑动窗口掩码

sliding_mask = torch.triu(
    all_ones, -1 * self.sliding_window_size + 1
) * torch.tril(all_ones, self.sliding_window_size - 1)
  • torch.triu(matrix, diagonal) 创建上三角矩阵,保留对角线及以上的元素。
  • torch.tril(matrix, diagonal) 创建下三角矩阵,保留对角线及以下的元素。
  • 两者相乘,得到 滑动窗口范围内的注意力掩码
示例

假设 sliding_window_size = 2

torch.triu(all_ones, -1 * 2 + 1)

参考笔者的博客:详解 torch.triu:上三角矩阵的高效构造(中英双语)
生成:

[[1, 1, 1, 1, 1],  
 [1, 1, 1, 1, 1],  
 [0, 1, 1, 1, 1],  
 [0, 0, 1, 1, 1],  
 [0, 0, 0, 1, 1]]
torch.tril(all_ones, 2 - 1)

生成:

[[1, 1, 0, 0, 0],  
 [1, 1, 1, 0, 0],  
 [1, 1, 1, 1, 0],  
 [1, 1, 1, 1, 1],  
 [1, 1, 1, 1, 1]]

相乘后:

[[0, 1, 0, 0, 0],  
 [1, 1, 1, 0, 0],  
 [0, 1, 1, 1, 0],  
 [0, 0, 1, 1, 1],  
 [0, 0, 0, 1, 1]]
  • 1 表示保留的注意力连接(窗口内)。
  • 0 表示屏蔽的部分(窗口外)。

(3) 使用 torch.where() 让窗口外的注意力分数变成 -inf

mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
  • 如果 sliding_mask[i, j] == 1,则 mask[i, j] 保持不变
  • 如果 sliding_mask[i, j] == 0,则 mask[i, j] = -inf,确保 softmax 计算时这些位置的权重变成 0。

4. 滑动窗口注意力的优化效果

✅ 计算复杂度降低

  • 标准 Transformer:( O ( n 2 ) O(n^2) O(n2) ) 计算量,无法处理长文本。
  • Sliding Window Attention:( O ( n ⋅ w ) O(n \cdot w) O(nw) ) 计算量,适用于 8K+ 长文本。

✅ 内存占用减少

  • 标准 Transformer:完整存储 ( n × n n \times n n×n ) 注意力矩阵,显存消耗巨大。
  • Sliding Window Attention:只存储 ( n × w n \times w n×w ) 相关的注意力分数,显存占用减少 ( O ( n ) O(n) O(n) ) 级别。

✅ 局部性优化

  • 适用于长文本任务,如 摘要、代码生成、长文问答
  • 保证计算效率的同时,不影响任务表现

5. 结论

💡 Gemma 2 采用 Sliding Window Attention,使得模型能够处理更长文本,同时保持计算效率!

  • 通过 滑动窗口掩码(Sliding Mask),限制 token 只能关注局部窗口范围内的 tokens。
  • 计算复杂度降低为 ( O ( n ⋅ w ) O(n \cdot w) O(nw) ),适用于 8K-16K 长文本任务。
  • 结合 局部注意力(Local Attention),确保计算高效,适用于 文档分析、QA、代码生成等任务

🚀 Gemma 2 结合 Sliding Window Attention,是长文本 Transformer 模型优化的重要方向之一!

后记

2025年2月23日14点44分于上海,在GPT4o大模型辅助下完成。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值