- Attention 注意力机制:将输入向量转换为增强的上下文向量。
下图是将在本文编写的四种注意力机制。
长序列建模问题
引入:翻译由于源语言和目标语言的语法结构,我们不能简单地逐字翻译文本
使用两个DNN模块:编码器和解码器。
- 在Transformer架构出现之前,RNN是最流行的编码器和解码器架构
- 编码器将整个输入处理成隐藏层状态,解码器仅依据此生成输出,无法从编码器访问早期的隐藏状态。可能会导致跨长句子依赖关系的丢失。
使用注意力机制捕获数据依赖关系
解码器可以有选择的访问所有输入令牌。
自注意力机制
定义
允许输入序列的每个位置在计算序列的表示时关注同一序列的所有位置。
目标
为每个输入元素,计算一个上下文向量。
作用
通过整合序列中所有其他元素的信息,为输入序列中的每一个元素创造出更丰富的表征。
简单自注意力
步骤
- 计算注意力权重(attention weights):通过输入向量
- 通过计算与每个输入的点积来计算:
attn_scores = inputs @ inputs.T
- 然后再进行softmax归一化得到:
attn_weights = torch.softmax(attn_scores, dim=1)
- 通过计算与每个输入的点积来计算:
- 计算上下文向量(context vector):通过注意力权重和输入向量
attn_weights @ inputs
exp
exp - 为输入元素
x
(
2
)
x^{(2)}
x(2) 计算上下文向量
z
(
2
)
z^{(2)}
z(2) :
通过矩阵乘法计算代码如下:
# 注意力得分:直接计算输入向量之间的点积
attn_scores = inputs @ inputs.T
# 注意力权重:归一化注意力得分
attn_weights = torch.softmax(attn_scores, dim=1)
all_context_vecs = attn_weights @ inputs
- 通过
attn_scores = inputs @ inputs.T
计算注意力得分
带有可训练权重的自注意力
通过三个权重矩阵来转换输入向量。
- 与简单自注意力的区别:模型训练期间会更新权重矩阵,使得模型能够学习“良好”的上下文向量。
步骤
- 初始查询
W
q
W_q
Wq ,键
W
k
W_k
Wk,值
W
v
W_v
Wv 权重矩阵。
- 然后
input
分别矩阵乘法 @
三个权重矩阵,得到queries
,keys
,values
三个向量。
- 然后
- 计算注意力权重:通过查询 queries 和键 keys 向量
- 不是简单自注意力直接计算输入向量之间的点积,而是通过:
attn_scores = queries @ keys.T
- 缩放点积归一化
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)
- 不是简单自注意力直接计算输入向量之间的点积,而是通过:
- 计算上下文向量:通过注意力权重和值 values 向量
context_vec = attn_weights @ values
exp - 为输入元素
x
(
2
)
x^{(2)}
x(2) 计算上下文向量
z
(
2
)
z^{(2)}
z(2) :
类比查询,键,值向量
- 查询向量 W q W_q Wq :类似于数据库中的搜索查询。表示模型当前关注或试图理解的项目。
- 键向量 W k W_k Wk :类似于数据库中的索引和搜索的键。输入序列中的每个项目都有一个关联的键。
- 值向量 W v W_v Wv :类似于数据库中键值对的值。一旦模型确定哪些输入部分(键)与当前项目(查询)最相关,它就检索相应的值。
实现 Self-Attention 类
# 导入PyTorch的神经网络模块
import torch.nn as nn
# 继承自nn.Module
class SelfAttenton_v2(nn.Module):
# qkv_bias:偏置向量,以提供额外的灵活性和模型的表达能力
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.d_out = d_out
# v1版本: self.W_query = nn.Parameter(torch.rand(d_in, d_out))
# nn.Linear 采用了比 nn.Parameter 更为复杂的权重初始化方案,且以转置形式存储权重矩阵
# 定义查询(Query)、键(Key)和值(Value)的线性变换层
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
# 前向传播函数,定义了自注意力机制的计算流程
def forward(self, x):
# 对输入x应用键(Key)、查询(Query)和值(Value)的线性变换
# W_key(x) == x @ W_key
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
# 计算查询和键的注意力分数
attn_scores = queries @ keys.T # 使用矩阵乘法计算注意力分数
# 应用softmax函数对注意力分数进行归一化,得到注意力权重
# 除以键的维度的平方根是为了进行缩放,防止梯度消失或爆炸
attn_weights = torch.softmax(
attn_scores / keys.shape[-1] ** 0.5, dim = -1)
# 使用注意力权重和值(Value)计算上下文向量
context_vec = attn_weights @ values # 使用矩阵乘法计算加权和
# 返回计算得到的上下文向量
return context_vec
- 引入单组查询 W q W_q Wq ,键 W k W_k Wk,值 W v W_v Wv 权重矩阵,
- 通过
attn_scores = queries @ keys.T
计算注意力得分
使用因果注意力机制隐藏后续词
只考虑序列中当前Token或之前出现的Token。
自然想到引入遮蔽矩阵(下三角矩阵)
步骤
- 前面与[[3-2带有可训练权重的自注意力#步骤|3-2]]得到注意力权重步骤相同。
- 对每个当前token后的注意力权重进行遮蔽,两种遮蔽方式:
- [[3-3因果注意力#遮蔽方式一]]
- [[3-3因果注意力#遮蔽方式二]]
- 计算上下文向量:通过注意力权重和值 values 向量,与带有可训练权重的自注意力相同
context_vec = attn_weights @ values
遮蔽方式一
attn_weights * mask_simple
与遮蔽矩阵(下三角矩阵)做乘法- 再归一化
# 引入遮蔽矩阵(下三角矩阵):
context_length = attn_weights.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
# 将遮蔽与注意力权重相乘,将对角线以上的值归零:
masked_simple = attn_weights * mask_simple
# 再归一化:
row_sums = masked_simple.sum(dim=1, keepdim=True) # 求每一行的和
masked_simple_norm = masked_simple / row_sums
伪 - 信息泄漏问题
可能出现打算遮蔽的Token仍影响当前Token,因它们的值时softmax函数计算的一部分。
但是softmax的数学优雅之处在于,尽管在最初的计算中分母包含了所有位置,但在遮蔽和重新归一化之后,被遮蔽的位置的影响被消除了
遮蔽方式二
利用softmax函数特性:负无穷趋近于0。
# mask(上三角矩阵),主对角线为0
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
# 填充attn_scores张量中的上三角部分为负无穷
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
引入Dropout防止过拟合
在训练过程中随机忽略选定的隐藏层单元(确保模型不会过度依赖任何特定的隐藏层单元),有效地丢弃,以防止过拟合。
# 例子:
torch.manual_seed(123)
# 为了补偿活跃元素的减少,矩阵中剩余元素的值被放大了 1/0.5 = 2 倍
dropout = torch.nn.Dropout(0.5) # 丢弃率50%
attn_weights = dropout(attn_weights)
实现 Casual Attention 类
import torch.nn as nn
class CausalAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
super().__init__()
self.d_out = d_out
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.dropout = nn.Dropout(dropout)
# register_buffer 不需要手动确保这些张量与模型参数在同一设备上,从而避免设备不匹配错误
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
# 将keys的第二维和第三维进行交换,下面有说明!!!
attn_scores = queries @ keys.transpose(1, 2)
attn_scores.masked_fill_(
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
context_vec = attn_weights @ values
return context_vec
keys.transpose(1, 2)说明
进行第二维和第三维交换:
变量 | 维度 |
---|---|
inputs | ([b, num_tokens, d_in]) |
W_ke, W_query, W_value | ([d_in, d_out]) |
keys, queries, values | ([b, num_tokens, d_out]) |
attn_scores, attn_weights | ([b, num_tokens, num_tokens]) |
context_vec | ([b, num_tokens, d_out]) |
-
queries 为
([b, num_tokens, d_out])
-
keys为
([b, num_tokens, d_out])
第二维和第三维交换后([b, d_out, num_tokens])
然后才对齐执行矩阵乘法,得到attn_scores([b, num_tokens, num_tokens])
-
对注意力权重进行遮蔽(masked attention)
多头注意力
- 多头:将注意力机制分为多个“头”,每个头独立运作。单个因果注意力模块可以被视为单头注意力,其中只有一组注意力权重顺序处理输入。
⚠️核心:多组查询 W q W_q Wq ,键 W k W_k Wk,值 W v W_v Wv 权重矩阵。
单头:
多头:
实现
顺序处理
- 缺点:需在forward方法中
[head(x) for head in self.heads]
顺序处理。
class MultiHeadAttentionWrapper(nn.Module):
def __init__(self, d_in, d_out, context_length,
dropout, num_heads, qkv_bias=False):
super().__init__()
self.heads = nn.ModuleList(
[CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
for _ in range(num_heads)]
)
def forward(self, x):
return torch.cat([head(x) for head in self.heads], dim=-1)
并行处理
矩阵乘法代替for循环:
关键操作是将 d_out 维度分割为 num_heads 和 head_dim,其中 head_dim = d_out / num_heads。这种分割随后通过 .view 方法实现:将维度为 (b, num_tokens, d_out) 的张量重塑为维度 (b, num_tokens, num_heads, head_dim)
- 优点:只需要一次矩阵乘法就可以计算出键。
带维度注释
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out,
context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads" # A
self.d_out = d_out
self.num_heads = num_heads # A
self.head_dim = d_out // num_heads # A
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # B
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x)
queries = self.W_query(x)
values = self.W_value(x)
# 将[(b, num_tokens, d_out)] 重塑为 [(b, num_tokens, num_heads, num_dim)]
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
# 将[(b, num_tokens, num_heads, num_dim)] 转置为 [(b, num_heads, num_tokens, num_dim)]
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# [(b, num_heads, num_tokens, num_dim)] @ [(b, num_heads, num_dim, num_tokens)]
# => attn_scores: [(b, num_heads, num_tokens, num_tokens)]
attn_scores = queries @ keys.transpose(2, 3)
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(
attn_scores / keys.shape[-1] ** 0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# [(b, num_heads, num_tokens, num_tokens)] @ [(b, num_heads, num_tokens, num_dim)]
# =>[(b, num_heads, num_tokens, num_dim)] 再转置
# => context_vec: [(b, num_tokens, num_heads, num_dim)]
context_vec = (attn_weights @ values).transpose(1, 2)
# [(b, num_tokens, num_heads, num_dim)] 重塑为 [(b, num_tokens, d_out)]
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out) # F
context_vec = self.out_proj(context_vec) # F
return context_vec
- 多组查询 W q W_q Wq ,键 W k W_k Wk,值 W v W_v Wv 权重矩阵。