1. Introduction
The Transformer model was first proposed by a Google team in the 2017 paper "Attention Is All You Need" and has fundamentally reshaped the landscape of natural language processing. Compared with traditional RNN and CNN models, the Transformer offers the following core advantages:
- Parallel computation: overcomes the sequential-dependency bottleneck of RNNs
- Long-range dependency modeling: captures global relationships through self-attention
- Scalability: well suited to building very large pre-trained models
2. Overall Architecture
2.1 Architecture Overview
The Transformer uses the classic encoder-decoder structure, consisting of N identical encoder layers and N identical decoder layers (N = 6 in the original paper).
```python
import torch.nn as nn

class Transformer(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model=512, N=6,
                 heads=8, dropout=0.1):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads, dropout)
        self.decoder = Decoder(tgt_vocab, d_model, N, heads, dropout)
        self.out = nn.Linear(d_model, tgt_vocab)

    def forward(self, src, tgt, src_mask, tgt_mask):
        # Encode the source, decode the target against it, then project to vocabulary logits
        e_out = self.encoder(src, src_mask)
        d_out = self.decoder(tgt, e_out, src_mask, tgt_mask)
        return self.out(d_out)
```
2.2 Encoder vs. Decoder Comparison

| Component | Encoder | Decoder |
| --- | --- | --- |
| Attention mechanism | Self-attention | Masked self-attention + encoder-decoder attention |
| Feed-forward network | Position-wise FFN | Position-wise FFN |
| Positional encoding | Sinusoidal positional encoding | Sinusoidal positional encoding |
| Number of layers | N (typically 6) | N (typically 6) |
3. Core Components in Detail
3.1 Self-Attention Mechanism
3.1.1 Computation
- Project the input into three matrices: Q (Query), K (Key), and V (Value)
- Compute the attention scores: $Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$
```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v, d_k, mask=None, dropout=None):
    # Scaled dot-product attention
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # block masked positions before softmax
    scores = F.softmax(scores, dim=-1)
    if dropout is not None:
        scores = dropout(scores)
    output = torch.matmul(scores, v)
    return output
```
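The `mask` argument is what turns this into the decoder's masked self-attention: positions where the mask is 0 are set to -1e9 before the softmax and so receive essentially zero weight. A minimal sketch of how such a causal (subsequent-position) mask could be built, assuming the convention that 1 means "may attend" and 0 means "blocked" (`causal_mask` is an illustrative helper, not part of the code above):
```python
import torch

def causal_mask(seq_len):
    # Lower-triangular matrix: position i may only attend to positions <= i
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.uint8)).unsqueeze(0)

# Example: a length-5 target sequence
mask = causal_mask(5)   # shape (1, 5, 5), broadcastable over the batch
print(mask[0])          # 1s on and below the diagonal, 0s above
```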
3.1.2 Multi-Head Attention
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k).transpose(1, 2)  # (batch, heads, seq, d_k)
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k).transpose(1, 2)
        if mask is not None:
            mask = mask.unsqueeze(1)  # broadcast the same mask over every head
        scores = attention(q, k, v, self.d_k, mask, self.dropout)
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)  # merge heads
        return self.out(concat)
```
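A quick shape check may help here: multi-head attention preserves the (batch, seq_len, d_model) shape of its input. A hypothetical usage sketch:
```python
import torch

mha = MultiHeadAttention(heads=8, d_model=512)
x = torch.randn(2, 10, 512)   # (batch=2, seq_len=10, d_model=512)
out = mha(x, x, x)            # self-attention: q = k = v = x
print(out.shape)              # torch.Size([2, 10, 512])
```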
3.2 Positional Encoding
Sine/cosine implementation:
$$PE_{(pos,2i)} = \sin(pos/10000^{2i/d_{model}})$$
$$PE_{(pos,2i+1)} = \cos(pos/10000^{2i/d_{model}})$$
```python
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len=200):
        super().__init__()
        self.d_model = d_model
        pe = torch.zeros(max_seq_len, d_model)
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Scale the embeddings, then add the (non-trainable) positional encodings
        x = x * math.sqrt(self.d_model)
        return x + self.pe[:, :x.size(1)]
```
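A brief usage sketch with illustrative sizes, assuming a standard nn.Embedding in front: the positional encoder is applied to the token embeddings before they enter the first encoder layer.
```python
import torch
import torch.nn as nn

d_model, vocab_size = 512, 10000
embed = nn.Embedding(vocab_size, d_model)
pos_enc = PositionalEncoder(d_model)

tokens = torch.randint(0, vocab_size, (2, 20))  # (batch=2, seq_len=20)
x = pos_enc(embed(tokens))                       # (2, 20, 512): embeddings + positions
```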
3.3 Feed-Forward Network
Position-wise feed-forward network:
$$FFN(x) = max(0, xW_1 + b_1)W_2 + b_2$$
```python
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        # Expand to d_ff with ReLU (plus dropout), then project back to d_model
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x
```
4. Mathematical Principles in Detail
4.1 Self-Attention Formula Derivation
Given an input matrix $X \in \mathbb{R}^{n \times d}$, compute:
$$Q = XW^Q, \quad K = XW^K, \quad V = XW^V$$
$$Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V$$
where $W^Q, W^K \in \mathbb{R}^{d \times d_k}$ and $W^V \in \mathbb{R}^{d \times d_v}$.
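To make the shapes concrete, a small worked sketch (dimensions chosen purely for illustration): with n = 3, d = 4, and d_k = d_v = 2, each row of the attention weights sums to 1 and the output lies in $\mathbb{R}^{n \times d_v}$.
```python
import torch
import torch.nn.functional as F

n, d, d_k = 3, 4, 2
X = torch.randn(n, d)
W_q, W_k, W_v = (torch.randn(d, d_k) for _ in range(3))

Q, K, V = X @ W_q, X @ W_k, X @ W_v           # each (n, d_k)
weights = F.softmax(Q @ K.T / d_k ** 0.5, dim=-1)
print(weights.sum(dim=-1))                    # each row sums to 1
print((weights @ V).shape)                    # torch.Size([3, 2]) = (n, d_v)
```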
4.2 Gradient Flow Analysis
Residual connections keep the gradient path clear; each sublayer output is wrapped as:
$$LayerNorm(x + Sublayer(x))$$
5. Complete Implementation Code
```python
# Encoder layer implementation (pre-norm variant: normalize first, apply the sublayer, add the residual)
class EncoderLayer(nn.Module):
    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Self-attention sublayer with residual connection
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn(x2, x2, x2, mask))
        # Feed-forward sublayer with residual connection
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x
```
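For symmetry with the comparison table in 2.2, here is a sketch of what the matching decoder layer could look like under the same conventions (this class is an illustration, not a verbatim part of the implementation): it adds a masked self-attention sublayer and an encoder-decoder attention sublayer over the encoder output.
```python
import torch.nn as nn

class DecoderLayer(nn.Module):
    """Sketch of a decoder layer: masked self-attention,
    encoder-decoder attention, then a feed-forward sublayer."""
    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.norm_3 = nn.LayerNorm(d_model)
        self.attn_1 = MultiHeadAttention(heads, d_model)  # masked self-attention
        self.attn_2 = MultiHeadAttention(heads, d_model)  # encoder-decoder attention
        self.ff = FeedForward(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)
        self.dropout_3 = nn.Dropout(dropout)

    def forward(self, x, e_out, src_mask, tgt_mask):
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn_1(x2, x2, x2, tgt_mask))
        x2 = self.norm_2(x)
        # Queries come from the decoder, keys/values from the encoder output
        x = x + self.dropout_2(self.attn_2(x2, e_out, e_out, src_mask))
        x2 = self.norm_3(x)
        x = x + self.dropout_3(self.ff(x2))
        return x
```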
6. Training and Inference
6.1 Training Procedure
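A minimal sketch of one standard sequence-to-sequence training step, assuming cross-entropy loss with teacher forcing, Encoder/Decoder stacks as referenced in section 2.1, the hypothetical `causal_mask` helper from 3.1.1, and illustrative hyperparameters (padding id 0, Adam settings):
```python
import torch
import torch.nn.functional as F

model = Transformer(src_vocab=10000, tgt_vocab=10000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-9)

def train_step(src, tgt, src_mask):
    # Teacher forcing: feed tgt[:, :-1], predict tgt[:, 1:]
    tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
    tgt_mask = causal_mask(tgt_in.size(1))
    logits = model(src, tgt_in, src_mask, tgt_mask)   # (batch, seq, tgt_vocab)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           tgt_out.reshape(-1), ignore_index=0)  # assume 0 = padding id
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```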
6.2 Inference Optimization Techniques
- Beam search
- Length penalty
- Temperature sampling (see the sketch after this list)
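As an illustration of the last item, a minimal sketch of autoregressive decoding with temperature sampling; the function and its parameters are hypothetical, with `sos_id`/`eos_id` standing for whatever start/end token ids the vocabulary uses, and `causal_mask` being the illustrative helper from 3.1.1:
```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample_decode(model, src, src_mask, sos_id, eos_id, max_len=50, temperature=0.8):
    # Start from the start-of-sequence token and grow the output one token at a time
    tgt = torch.full((src.size(0), 1), sos_id, dtype=torch.long)
    for _ in range(max_len):
        tgt_mask = causal_mask(tgt.size(1))
        logits = model(src, tgt, src_mask, tgt_mask)[:, -1]  # logits for the next token
        probs = F.softmax(logits / temperature, dim=-1)      # temperature < 1 sharpens, > 1 flattens
        next_token = torch.multinomial(probs, num_samples=1)
        tgt = torch.cat([tgt, next_token], dim=1)
        if (next_token == eos_id).all():
            break
    return tgt
```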
7. Applications and Extensions
- BERT: encoder-only
- GPT: decoder-only
- Transformer-XL: handles longer sequences
- Vision Transformer: computer vision applications
8. Summary
With an architecture built entirely on attention, the Transformer removes the fundamental limitations of traditional sequence models. Its core innovations include:
- Parallelizable self-attention computation
- The positional encoding scheme
- Residual connections and layer normalization
- A stackable, modular design
As the era of large models unfolds, the Transformer architecture continues to evolve and to push the boundaries of AI.