Gemma 2 的滑动窗口注意力(Sliding Window Attention)解析
在 Transformer 结构 中,自注意力(Self-Attention)是核心机制之一。然而,标准的自注意力计算复杂度为 ( O ( n 2 ) O(n^2) O(n2) ),随着序列长度增加,计算和内存开销都会急剧增长。为了解决这一问题,Gemma 2 采用了滑动窗口注意力(Sliding Window Attention)机制,通过限制每个 token 只能关注附近的 tokens,大幅降低计算复杂度,使模型能够高效处理更长的文本序列。
本文将详细解析 Gemma 2 的 Sliding Window Attention 的实现原理、代码细节以及它带来的优化效果。
1. 为什么需要滑动窗口注意力?
1.1 标准 Transformer 的计算瓶颈
在标准 Transformer 中,自注意力计算如下:
Attention
(
Q
,
K
,
V
)
=
softmax
(
Q
K
T
d
k
)
V
\text{Attention}(Q, K, V) = \text{softmax} \left(\frac{QK^T}{\sqrt{d_k}} \right) V
Attention(Q,K,V)=softmax(dkQKT)V
其中:
- ( Q , K , V Q, K, V Q,K,V ) 是 查询(Query)、键(Key)、值(Value) 矩阵,形状为 ( ( n , d k ) (n, d_k) (n,dk) )。
- 计算复杂度为 ( O ( n 2 d k ) O(n^2 d_k) O(n2dk) ),其中 ( n n n ) 是序列长度。
问题:
- 对于 长文本(如 8K、16K token),标准 Transformer 计算量和显存占用急剧上升。
- 由于 每个 token 计算所有 token 的注意力分数,使得 全自注意力不适用于长文本任务。
1.2 Sliding Window Attention 如何优化?
滑动窗口注意力(Sliding Window Attention)核心思想:
- 每个 token 仅关注窗口范围内的 tokens,而不是整个序列。
- 计算复杂度降低为 ( O ( n ⋅ w ) O(n \cdot w) O(n⋅w) ),其中 ( w w w ) 是窗口大小。
- 适用于 长文本任务,如文档摘要、问答、代码补全等。
📌 示例
- 如果窗口大小 ( w = 512 ):
- 第 1000 个 token 仅关注
[488, ..., 1000, ..., 1512]
范围内的 tokens。 - 而不是像标准 Transformer 一样计算
[0, ..., 1000, ..., n]
之间的所有 token 交互。
- 第 1000 个 token 仅关注
🔹 这样,每个 token 只计算窗口内 tokens 的注意力,大幅减少计算量!
可以参考笔者的另一篇博客:Sliding Window Attention(滑动窗口注意力)解析: Pytorch实现并结合全局注意力(Global Attention )
2. Gemma 2 中的滑动窗口注意力实现
在 Gemma 2 的 GemmaAttention
代码中,滑动窗口注意力的实现如下:
if (
self.attn_type == gemma_config.AttentionType.LOCAL_SLIDING
and self.sliding_window_size is not None
):
all_ones = torch.ones_like(mask)
sliding_mask = torch.triu(
all_ones, -1 * self.sliding_window_size + 1
) * torch.tril(all_ones, self.sliding_window_size - 1)
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
3. 代码解析
这段代码的作用是 创建滑动窗口掩码(Sliding Mask),确保每个 token 只能关注窗口范围内的 tokens。
(1) 创建全 1 矩阵
all_ones = torch.ones_like(mask)
- 生成一个与
mask
形状相同的全 1 矩阵all_ones
。 - 这个矩阵用于构建滑动窗口的掩码。
示例
假设 mask
形状为 [batch, num_heads, seq_len, seq_len]
,即:
all_ones = torch.tensor([
[[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]]
])
(2) 生成滑动窗口掩码
sliding_mask = torch.triu(
all_ones, -1 * self.sliding_window_size + 1
) * torch.tril(all_ones, self.sliding_window_size - 1)
torch.triu(matrix, diagonal)
创建上三角矩阵,保留对角线及以上的元素。torch.tril(matrix, diagonal)
创建下三角矩阵,保留对角线及以下的元素。- 两者相乘,得到 滑动窗口范围内的注意力掩码。
示例
假设 sliding_window_size = 2
:
torch.triu(all_ones, -1 * 2 + 1)
参考笔者的博客:详解 torch.triu:上三角矩阵的高效构造(中英双语)
生成:
[[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1],
[0, 1, 1, 1, 1],
[0, 0, 1, 1, 1],
[0, 0, 0, 1, 1]]
torch.tril(all_ones, 2 - 1)
生成:
[[1, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[1, 1, 1, 1, 0],
[1, 1, 1, 1, 1],
[1, 1, 1, 1, 1]]
相乘后:
[[0, 1, 0, 0, 0],
[1, 1, 1, 0, 0],
[0, 1, 1, 1, 0],
[0, 0, 1, 1, 1],
[0, 0, 0, 1, 1]]
1
表示保留的注意力连接(窗口内)。0
表示屏蔽的部分(窗口外)。
(3) 使用 torch.where()
让窗口外的注意力分数变成 -inf
mask = torch.where(sliding_mask == 1, mask, -2.3819763e38)
- 如果
sliding_mask[i, j] == 1
,则mask[i, j]
保持不变。 - 如果
sliding_mask[i, j] == 0
,则mask[i, j] = -inf
,确保 softmax 计算时这些位置的权重变成 0。
4. 滑动窗口注意力的优化效果
✅ 计算复杂度降低
- 标准 Transformer:( O ( n 2 ) O(n^2) O(n2) ) 计算量,无法处理长文本。
- Sliding Window Attention:( O ( n ⋅ w ) O(n \cdot w) O(n⋅w) ) 计算量,适用于 8K+ 长文本。
✅ 内存占用减少
- 标准 Transformer:完整存储 ( n × n n \times n n×n ) 注意力矩阵,显存消耗巨大。
- Sliding Window Attention:只存储 ( n × w n \times w n×w ) 相关的注意力分数,显存占用减少 ( O ( n ) O(n) O(n) ) 级别。
✅ 局部性优化
- 适用于长文本任务,如 摘要、代码生成、长文问答。
- 保证计算效率的同时,不影响任务表现。
5. 结论
💡 Gemma 2 采用 Sliding Window Attention,使得模型能够处理更长文本,同时保持计算效率!
- 通过 滑动窗口掩码(Sliding Mask),限制 token 只能关注局部窗口范围内的 tokens。
- 计算复杂度降低为 ( O ( n ⋅ w ) O(n \cdot w) O(n⋅w) ),适用于 8K-16K 长文本任务。
- 结合 局部注意力(Local Attention),确保计算高效,适用于 文档分析、QA、代码生成等任务。
🚀 Gemma 2 结合 Sliding Window Attention,是长文本 Transformer 模型优化的重要方向之一!
后记
2025年2月23日14点44分于上海,在GPT4o大模型辅助下完成。