前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。
https://www.captainbed.cn/north
文章目录
1. 引言:注意力机制的崛起
在深度学习发展历程中,Transformer架构及其核心的自注意力机制(Self-Attention)无疑是最具革命性的突破之一。2017年Google发表的《Attention Is All You Need》论文彻底改变了自然语言处理(NLP)的格局,并迅速扩展到计算机视觉、语音识别、生物信息学等多个领域。本文将深入剖析自注意力机制的工作原理、数学基础、实现细节,并探讨它为何能如此深刻地改变人工智能的发展轨迹。
2. 自注意力机制的核心思想
2.1 基本概念
自注意力机制是一种允许输入序列的每个元素(如句子中的单词)与序列中所有其他元素进行交互并计算其相对重要性的机制。与传统序列模型(如RNN、LSTM)相比,它具有三大核心优势:
- 全局依赖性:直接建模任意距离元素间的关系
- 并行计算:摆脱了RNN的序列计算约束
- 动态权重:根据输入内容动态调整关注权重
2.2 直观理解
想象阅读一段文字时,人类会自然地关注与当前理解最相关的其他词语。例如:
“动物没有过马路,因为它太累了”
人类会自然地建立"它"与"动物"之间的联系而非"马路"。自注意力机制正是模拟这种动态的、基于内容的关联建模能力。
3. 自注意力机制的数学原理
3.1 关键概念定义
给定输入序列 X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d(n个token,每个维度为d),自注意力机制通过三个可学习矩阵计算:
- 查询(Query): Q = X W Q Q = XW_Q Q=XWQ, W Q ∈ R d × d k W_Q \in \mathbb{R}^{d \times d_k} WQ∈Rd×dk
- 键(Key): K = X W K K = XW_K K=XWK, W K ∈ R d × d k W_K \in \mathbb{R}^{d \times d_k} WK∈Rd×dk
- 值(Value): V = X W V V = XW_V V=XWV, W V ∈ R d × d v W_V \in \mathbb{R}^{d \times d_v} WV∈Rd×dv
其中 d k d_k dk是key/query的维度, d v d_v dv是value的维度。
3.2 注意力计算步骤
-
相似度计算:计算query与所有key的点积
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dkQKT)V -
缩放点积:除以 d k \sqrt{d_k} dk防止梯度消失
-
Softmax归一化:得到注意力权重
-
加权求和:用权重对value进行加权
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim=-1)
return torch.matmul(p_attn, V), p_attn
3.3 多头注意力(Multi-Head Attention)
为了捕捉不同子空间的特征,Transformer使用多头注意力:
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,...,\text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中每个head的计算为:
head
i
=
Attention
(
Q
W
i
Q
,
K
W
i
K
,
V
W
i
V
)
\text{head}_i = \text{Attention}(QW_i^Q,KW_i^K,VW_i^V)
headi=Attention(QWiQ,KWiK,VWiV)
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.d_k = d_model // num_heads
self.num_heads = num_heads
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 线性变换并分割头
Q = self.W_Q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
K = self.W_K(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
V = self.W_V(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
# 计算注意力
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# 合并头并输出
attn_output = attn_output.transpose(1, 2).contiguous().view(
batch_size, -1, self.num_heads * self.d_k)
return self.W_O(attn_output)
4. 为什么自注意力改变了AI?
4.1 革命性的架构优势
特性 | RNN/LSTM | Transformer |
---|---|---|
长程依赖 | 困难(梯度消失) | 直接建模 |
并行计算 | 不可并行 | 完全并行 |
计算复杂度 | O(n) | O(n²) |
信息流动 | 顺序 | 全连接 |
实际表现 | 受限于序列长度 | 超长序列表现良好 |
4.2 关键突破点
- 并行化训练:摆脱了RNN的序列依赖性
- 全局上下文:每个位置直接访问所有位置信息
- 表示能力:动态权重比固定架构更灵活
- 可扩展性:适合大规模预训练
4.3 跨领域影响
- NLP领域:BERT、GPT、T5等模型统治各类任务
- 计算机视觉:Vision Transformer (ViT) 超越CNN
- 多模态学习:CLIP、DALL-E等跨模态模型
- 科学计算:AlphaFold2解决蛋白质折叠问题
5. Transformer完整架构
5.1 编码器-解码器结构
graph TD
A[输入序列] --> B[编码器]
B --> C[解码器]
C --> D[输出序列]
subgraph 编码器
B --> E[N×编码器层]
E --> F[多头自注意力]
F --> G[前馈网络]
G --> H[残差连接+层归一化]
end
subgraph 解码器
C --> I[N×解码器层]
I --> J[掩码多头自注意力]
J --> K[多头编码-解码注意力]
K --> L[前馈网络]
L --> M[残差连接+层归一化]
end
5.2 关键组件实现
class TransformerLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, num_heads)
self.feed_forward = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力子层
attn_output = self.self_attn(x, x, x, mask)
x = x + self.dropout(attn_output)
x = self.norm1(x)
# 前馈子层
ff_output = self.feed_forward(x)
x = x + self.dropout(ff_output)
x = self.norm2(x)
return x
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
return x + self.pe[:x.size(1)]
6. 自注意力的变体与改进
6.1 稀疏注意力
解决O(n²)复杂度问题:
- 局部注意力:限制注意力窗口大小
- 轴向注意力:分别处理不同维度
- 稀疏变换器:预定义稀疏模式
6.2 高效注意力
- 线性注意力:将softmax近似为核函数
Sim ( Q , K ) = ϕ ( Q ) ϕ ( K ) T \text{Sim}(Q,K) = \phi(Q)\phi(K)^T Sim(Q,K)=ϕ(Q)ϕ(K)T - 内存压缩:聚类或降维key/value
- 分块计算:将长序列分块处理
6.3 相对位置编码
原始Transformer使用绝对位置编码,改进方案:
- 相对位置偏置:在注意力分数中加入相对距离项
- 旋转位置编码:RoPE (Rotary Position Embedding)
class RotaryPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
def forward(self, seq_len, device):
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
return torch.cat((freqs, freqs), dim=-1)
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, pos_emb):
cos, sin = pos_emb.cos(), pos_emb.sin()
q = (q * cos) + (rotate_half(q) * sin)
k = (k * cos) + (rotate_half(k) * sin)
return q, k
7. 自注意力机制的应用案例
7.1 NLP领域:GPT-3
- 纯解码器架构:仅使用Transformer解码器
- 自回归生成:逐个token预测
- 零样本学习:1750亿参数实现强大泛化
7.2 计算机视觉:Vision Transformer
class ViT(nn.Module):
def __init__(self, image_size, patch_size, num_classes, d_model, num_heads, num_layers):
super().__init__()
num_patches = (image_size // patch_size) ** 2
self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))
self.transformer = nn.ModuleList([
TransformerLayer(d_model, num_heads, d_model*4) for _ in range(num_layers)
])
self.head = nn.Linear(d_model, num_classes)
def forward(self, x):
# 分块嵌入
x = self.patch_embed(x) # [B, C, H, W] -> [B, D, H/P, W/P]
x = x.flatten(2).transpose(1, 2) # [B, N, D]
# 添加[CLS] token和位置编码
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
# Transformer编码器
for layer in self.transformer:
x = layer(x)
# 分类头
return self.head(x[:, 0])
7.3 多模态学习:CLIP
- 双编码器架构:图像和文本分别编码
- 对比学习:最大化匹配对的相似度
- 注意力交互:跨模态注意力机制
8. 自注意力机制的局限性
- 计算复杂度:O(n²)内存和计算需求
- 长序列处理:尽管优于RNN,但仍面临挑战
- 训练难度:需要大规模数据和计算资源
- 解释性:注意力权重不一定反映真实重要性
9. 未来发展方向
- 高效注意力:突破平方复杂度限制
- 动态架构:根据输入调整计算路径
- 神经符号结合:融合符号推理能力
- 生物启发改进:借鉴人脑注意力机制
10. 结论
自注意力机制之所以能深刻改变AI领域,核心在于它提供了一种灵活、并行、全局的信息处理范式,突破了传统序列模型的根本限制。从理论上看,它实际上是一种基于内容的内存寻址机制,可以看作现代计算机"随机访问内存"概念在神经网络中的实现。随着研究的深入,自注意力机制将继续演化,推动人工智能向更通用、更高效的方向发展。
附录:自注意力完整计算流程图
graph LR
A[输入X] --> B[计算Q,K,V]
B --> C[Q×K^T]
C --> D[Scale:除以√d_k]
D --> E[Softmax归一化]
E --> F[加权求和V]
F --> G[输出]
subgraph 多头注意力
B --> H[分割多头]
H --> C
F --> I[合并多头]
end