📁 Project Directory Structure
transformer-tutorial/
├── data/
│   └── sample_text.txt
├── model/
│   ├── attention.py
│   ├── encoder.py
│   ├── positional_encoding.py
│   └── transformer.py
├── train.py
├── predict.py
└── requirements.txt
✅ Environment Setup
python -m venv transformer-env
source transformer-env/bin/activate
pip install torch numpy tqdm matplotlib
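To quickly confirm the environment works, you can run a small optional sanity check (not part of the tutorial files):

import torch
print(torch.__version__)          # installed PyTorch version
print(torch.cuda.is_available())  # True if a CUDA GPU is usable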
🧠 Step 1: Build the Positional Encoding
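The layer below implements the standard sinusoidal encoding from "Attention Is All You Need": for position $pos$ and dimension index $i$,

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right),\qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)$$

so even dimensions use sine and odd dimensions use cosine, exactly as the two slice assignments in the code do.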
import torch
import math

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # Precompute the sinusoidal table once: shape (1, max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)
        # register_buffer stores the table on the module (so it moves with .to(device))
        # without treating it as a trainable parameter
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encoding for the first seq_len positions
        return x + self.pe[:, :x.size(1)]
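A quick shape check (a minimal sketch, the tensor sizes are arbitrary) shows the layer simply adds the encoding without changing the input shape:

import torch

pe = PositionalEncoding(d_model=512)
x = torch.zeros(2, 10, 512)   # (batch, seq_len, d_model)
print(pe(x).shape)            # torch.Size([2, 10, 512])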
🔁 Step 2: Implement Self-Attention
import torch
import torch.nn as nn

class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super().__init__()
        self.scale = d_k ** 0.5

    def forward(self, Q, K, V, mask=None):
        # scores: (..., seq_len_q, seq_len_k)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            # Positions where mask == 0 are blocked from attention
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = torch.nn.functional.softmax(scores, dim=-1)
        output = torch.matmul(attn, V)
        return output, attn
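To see the masking path in action, here is a small sketch (shapes chosen arbitrarily, not part of the tutorial files) that applies a lower-triangular causal mask and checks that the attention weights still sum to 1 per query position:

import torch

attn_layer = ScaledDotProductAttention(d_k=64)
Q = K = V = torch.randn(2, 5, 64)                 # (batch, seq_len, d_k)
mask = torch.tril(torch.ones(5, 5)).unsqueeze(0)  # causal mask, broadcast over batch
out, attn = attn_layer(Q, K, V, mask)
print(out.shape)             # torch.Size([2, 5, 64])
print(attn.sum(dim=-1))      # each row of weights sums to 1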
💥 Step 3: Build Multi-Head Attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_k = self.d_v = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attn = ScaledDotProductAttention(self.d_k)

    def forward(self, Q, K, V, mask=None):
        B, T_q, D = Q.size()
        T_k = K.size(1)
        # Project, then split into heads: (B, num_heads, seq_len, d_k)
        Q = self.W_q(Q).view(B, T_q, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(B, T_k, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(B, T_k, self.num_heads, self.d_k).transpose(1, 2)
        out, attn = self.attn(Q, K, V, mask)
        # Merge the heads back: (B, seq_len_q, d_model)
        out = out.transpose(1, 2).contiguous().view(B, T_q, D)
        return self.W_o(out)
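As a sanity check (a minimal sketch with arbitrary sizes), the module maps (batch, seq_len, d_model) to the same shape; the head split happens entirely inside:

import torch

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
out = mha(x, x, x)   # self-attention: Q, K, V are all x
print(out.shape)     # torch.Size([2, 10, 512])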
🧱 Step 4: Encoder Layer
import torch.nn as nn
from model.attention import MultiHeadAttention

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dim_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        # Position-wise feed-forward network
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_ff),
            nn.ReLU(),
            nn.Linear(dim_ff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention sub-layer: residual connection + post-LayerNorm
        x2 = self.attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(x2))
        # Feed-forward sub-layer: residual connection + post-LayerNorm
        x2 = self.ff(x)
        x = self.norm2(x + self.dropout(x2))
        return x
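Each encoder layer is also shape-preserving, which is what lets us stack several of them in Step 5. A quick check (sizes are illustrative, not from the tutorial):

import torch

layer = TransformerEncoderLayer(d_model=512, num_heads=8, dim_ff=2048)
x = torch.randn(2, 10, 512)
print(layer(x).shape)   # torch.Size([2, 10, 512])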
🔄 Step 5: Build the Full Transformer Model
import torch.nn as nn
from model.encoder import TransformerEncoderLayer
from model.positional_encoding import PositionalEncoding

class SimpleTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=512, num_heads=8, num_layers=6, dim_ff=2048, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        self.encoder_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, dim_ff) for _ in range(num_layers)
        ])
        # Project the final hidden states back to vocabulary logits
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src):
        x = self.embedding(src)
        x = self.pos_encoder(x)
        for layer in self.encoder_layers:
            x = layer(x)
        return self.fc(x)
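Instantiating the default configuration and counting parameters gives a feel for the model size (a quick sketch; the exact count depends on the vocabulary size you choose):

import torch

model = SimpleTransformer(vocab_size=10000)
n_params = sum(p.numel() for p in model.parameters())
print(f"Parameters: {n_params:,}")

tokens = torch.randint(0, 10000, (2, 20))
print(model(tokens).shape)   # torch.Size([2, 20, 10000]) -- per-token vocabulary logits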
⚙️ Sample Training Code (train.py)
import torch
from model.transformer import SimpleTransformer

model = SimpleTransformer(vocab_size=10000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(10):
    # Random token IDs stand in for a real dataset; this only demonstrates the training loop
    input_ids = torch.randint(0, 10000, (32, 128))
    labels = input_ids.clone()
    output = model(input_ids)
    loss = loss_fn(output.view(-1, 10000), labels.view(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch} Loss: {loss.item():.4f}")
📈 Inference Code (predict.py)
import torch
from model.transformer import SimpleTransformer

model = SimpleTransformer(vocab_size=10000)
model.eval()

input_seq = torch.randint(0, 10000, (1, 20))
with torch.no_grad():
    output = model(input_seq)
    pred = output.argmax(dim=-1)

print("Input:", input_seq)
print("Prediction:", pred)