模型训练通常使用 批处理batch来提升训练效率。而实际中Transformer的输入序列(如句子、文本片段)长度往往不一致。为了让这些样本可以组成一个统一的形状 [B, T]
的张量给GPU进行并行计算提高效率,需要将较短的序列填充(pad)到最大长度,比如统一为 T=10
。填充内容一般是特殊的 <PAD>
符号,或者对应的 token ID(如 0)。
Transformer在训练和推理时,padding的作用不同。训练阶段,为并行处理多个长度不一的序列,需用padding统一长度,确保模型能一次性处理整个批次。推理阶段,单序列生成时无需padding,模型逐token输出;但若采用批处理推理,仍需padding对齐不同长度的序列。因此,训练时padding必不可少,推理时视处理方式而定。
整个 Transformer 结构中涉及到的 “掩码” 类型一共有两种:① 用于区分同一个 batch 中不同长度序列是否被填充的 key padding mask;② 在训练时,Decoder 中用于模仿推理过程中在编码当前时禁止看到未来信息的 attention mask(也叫做 casual mask 或 future mask)
对于自回归模型来说,我们在训练阶段为了加速模型的训练过程,都是直接使用小批量样本来进行训练,而这就导致了不同样本长度不能进行计算的问题,因而要使用填充操作来使同一个 batch 中的序列长度保持一致。同时,在自注意力机制中,为了防止注意力被分配到 padding 位置上,所以我们需要通过 padding mask 来忽略掉这部分注意力值。
名称 | 核心作用 | 应用模块 | 是否与batch相关 | 典型形状与值示例 | 使用场景 |
---|---|---|---|---|---|
key_padding_mask | 屏蔽无效的padding符号(如句末填充的占位符) | Encoder层、Decoder层的自注意力与交叉注意力 | ✅ 是(每个样本独立) | [B, T] 例: [0,0,-inf] | 处理变长序列时过滤padding |
causal mask (attention_mask) | 防止模型看到未来时刻的信息(保持时序依赖) | Decoder层的自注意力 | ❌ 否(全局共享) | [T, T] 例:上三角矩阵填充 -inf | 文本生成、自回归预测任务 |
- Query:代表当前生成任务的需求(如"接下来该生成什么词?"),通过当前隐藏状态线性变换生成
- Key:作为历史上下文的索引(如已生成的词语),用于与Query匹配相关性
- Value:存储实际语义内容,根据注意力权重提取关键信息
- 动态信息检索系统:Q提出问题 → K响应线索 → V提供语义内容
- 注意力权重计算:通过Q与K的点积相似度 → 缩放 → softmax归一化 → 加权聚合V
- 线性变换:每个输入位置的词向量通过独立权重矩阵生成Q/K/V
- 缩放点积:为防止梯度消失,需除以√d_k(d_k为向量维度)
- 多头注意力:分割为多个子空间(如8头),分别捕捉语法/语义等多维度关系
- 位置编码:通过正弦/余弦函数注入位置信息,保持序列顺序感知
- 自回归生成:通过因果掩码确保仅关注历史信息
- 注意力聚焦:生成"这里"时,Q高匹配"公众号"的K,提取其V的语义特征
- 概率预测:加权后的上下文向量经FFN输出词表概率分布
import torch
import torch.nn.functional as F
# ============== 输入设置 ==============
# 两个示例序列:
# 序列1: [1, 2, 0] (0是填充符号)
# 序列2: [3, 4, 5] (没有填充)
inputs = torch.tensor([[1, 2, 0], [3, 4, 5]])
print("输入序列:")
print(inputs)
# ============== 1. 填充掩码 ==============
print("\n== 第一部分: 填充掩码 ==")
# 填充掩码标识哪些位置包含填充符(0)
# True = 这个位置应被忽略
# False = 这个位置包含真实数据
key_padding_mask = (inputs == 0)
print("\n填充掩码 (True表示'忽略此位置'):")
print(key_padding_mask)
print("解释: 第一个序列的最后一位是填充符,所以标记为True")
# ============== 2. 因果掩码 ==============
print("\n== 第二部分: 因果掩码 ==")
# 因果掩码防止token关注未来位置
# 这对GPT等自回归模型至关重要
seq_len = inputs.size(1)
# 创建下三角矩阵(包括对角线),1表示可见,0表示被掩盖
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
# 转换为布尔掩码(0->True表示需掩盖,1->False表示可见)
causal_mask = causal_mask == 0
print("\n因果掩码 (True表示'需掩盖'):")
print(causal_mask)
print("解释: 每个位置只能看到自己和之前的位置")
# ============== 3. 应用掩码到注意力分数 ==============
print("\n== 第三部分: 应用掩码 ==")
# 模拟注意力分数
attn_scores = torch.randn(2, 3, 3)
print("\n原始注意力分数 (随机):")
print(attn_scores)
# 应用填充掩码
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(1), # [2,3] -> [2,1,3]
-1e9 # 非常小的值,softmax后接近0
)
# 应用因果掩码
attn_scores = attn_scores.masked_fill(
causal_mask.unsqueeze(0), # [3,3] -> [1,3,3]
-1e9
)
# 计算最终注意力权重
attn_weights = F.softmax(attn_scores, dim=-1)
print("\n最终注意力权重:")
print(attn_weights)
print("解释: 权重已被两种掩码调整,填充位置和未来位置的权重接近0")
Transformer 中 Decoder 里的 tgt key padding mask 操作其实是不需要的。因为计算损失时只会考虑非padding的字符,padding字符不参与损失计算。参考:Transformer中Decoder真的不需要PaddingMask?QKV到底是怎么来的?
注意力机制和mask掩码示例代码
import torch
import torch.nn as nn
# 设置随机种子,确保实验结果可复现
torch.random.manual_seed(42)
# 定义模型维度和序列长度
d_model = 7 # 模型的维度
tgt_len = 6 # 目标序列长度
src_len = 5 # 源序列长度
# =================== 编码器部分 ===================
# 随机生成源序列的嵌入表示
src_em_input = torch.randn(src_len, d_model)
# 定义线性变换层用于计算Q、K、V
q_proj0 = nn.Linear(d_model, d_model)
k_proj0 = nn.Linear(d_model, d_model)
v_proj0 = nn.Linear(d_model, d_model)
# 计算Q、K、V
q, k, v = q_proj0(src_em_input), k_proj0(src_em_input), v_proj0(src_em_input)
# 创建源序列的填充掩码
src_key_padding_mask = torch.tensor([[False, False, False, True, True]])
# 计算注意力权重
weight = q @ k.transpose(0, 1) # [src_len, src_len]
print("Encoder attention weight before key padding:")
print(weight)
# 应用填充掩码
weight = weight.masked_fill(src_key_padding_mask, float('-inf'))
print("Encoder attention weight after key padding:")
print(weight)
# 计算softmax值
weight = torch.softmax(weight, dim=-1) # [src_len, src_len]
# 得到加权后的输出(记忆)
memory = weight @ v # [src_len, d_model]
# =================== 解码器部分 ===================
## ------------------ 带遮罩的自注意力机制
# 随机生成目标序列的嵌入表示
tgt_em_input = torch.randn(tgt_len, d_model)
# 创建目标序列的填充掩码
tgt_key_padding_mask = torch.tensor([[False, False, False, True, True, True]])
# 创建上三角矩阵作为遮罩,防止未来信息泄露
attn_mask = (torch.triu(torch.ones((tgt_len, tgt_len))) == 1).transpose(0, 1)
attn_mask = attn_mask.float().masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.))
# 定义线性变换层用于计算Q、K、V
q_proj1 = nn.Linear(d_model, d_model)
k_proj1 = nn.Linear(d_model, d_model)
v_proj1 = nn.Linear(d_model, d_model)
# 计算Q、K、V
q, k, v = q_proj1(tgt_em_input), k_proj1(tgt_em_input), v_proj1(tgt_em_input) # [tgt_len, d_model]
# 计算注意力权重
weight = q @ k.transpose(0, 1) # [tgt_len, tgt_len]
weight += attn_mask
print("Decoder masked attention weight before key padding:")
print(weight)
# 这里控制是否应用填充掩码
weight = weight.masked_fill(tgt_key_padding_mask, float('-inf'))
print("Decoder masked attention weight after key padding:")
print(weight)
weight = torch.softmax(weight, dim=-1) # [tgt_len, tgt_len]
output = weight @ v # [tgt_len, tgt_len] @ [tgt_len, d_model]
## ------------------ 交叉注意力机制
# 定义线性变换层用于计算Q、K、V
q_proj2 = nn.Linear(d_model, d_model)
k_proj2 = nn.Linear(d_model, d_model)
v_proj2 = nn.Linear(d_model, d_model)
# 计算Q、K、V,其中K和V来自编码器的记忆
q, k, v = q_proj2(output), k_proj2(memory), v_proj2(memory) # q: [tgt_len, d_model], k,v: [src_len, d_model]
# 计算注意力权重
weight = q @ k.transpose(0, 1) # [tgt_len, src_len]
# 应用源序列的填充掩码
weight = weight.masked_fill(src_key_padding_mask, float('-inf'))
weight = torch.softmax(weight, dim=-1)
output = weight @ v # [tgt_len, d_model]
# =================== 损失计算 ===================
# 定义目标标签和损失函数
tgt = torch.tensor([1, 2, 3, 0, 0, 0], dtype=torch.long)
loss_fcn = torch.nn.CrossEntropyLoss(ignore_index=0) # 忽略填充标记(这里为0)的损失计算
loss = loss_fcn(output, tgt)
print(loss)
完整transformer代码
import torch
import torch.nn as nn
import math
from torch.nn import Transformer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
# -------------------- 1. 数据集生成 --------------------
class MyDataset(Dataset):
def __init__(self, num_samples=1000, max_len=10, vocab_size=20):
"""
生成随机长度的数字序列数据集
关键点:
- 序列长度随机(模拟真实场景的变长序列)
- 0被保留为填充符号,因此有效token从1开始
"""
self.data = []
for _ in range(num_samples):
# 生成3到max_len之间的随机长度
seq_len = torch.randint(3, max_len, (1,)).item()
# 生成序列内容(注意:从1开始,0留给填充)
sequence = torch.randint(1, vocab_size, (seq_len,))
self.data.append(sequence)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# -------------------- 2. 数据处理与填充 --------------------
def collate_fn(batch):
"""
批处理函数,主要完成:
1. 动态填充:将不同长度的序列填充到相同长度
2. 生成目标序列(输入右移一位)
3. 注意:实际任务中需要根据任务类型调整目标生成方式
"""
# 对输入进行填充(batch_first=True表示batch维度在前)
# padding_value=0 表示用0填充短序列
padded_inputs = pad_sequence(
batch,
batch_first=True,
padding_value=0
)
# 生成目标序列(将输入序列右移一位,末尾补0)
# 例如:输入 [1,2,3] → 目标 [2,3,0]
targets = pad_sequence(
[torch.cat([seq[1:], torch.tensor([0])]) for seq in batch],
batch_first=True,
padding_value=0
)
return padded_inputs, targets
# -------------------- 3. 位置编码 --------------------
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
"""
位置编码实现:
- 使用正弦和余弦函数生成位置编码
- 公式:PE(pos,2i) = sin(pos/10000^(2i/d_model))
PE(pos,2i+1) = cos(pos/10000^(2i/d_model))
"""
super().__init__()
# 创建位置编码矩阵 (max_len, d_model)
pe = torch.zeros(max_len, d_model)
# 生成位置序列 (0到max_len-1)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 计算div_term用于调整正弦波频率
div_term = torch.exp(
torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model)
)
# 填充位置编码矩阵
pe[:, 0::2] = torch.sin(position * div_term) # 偶数位置用sin
pe[:, 1::2] = torch.cos(position * div_term) # 奇数位置用cos
# 扩展维度:(1, max_len, d_model) 用于后续广播
pe = pe.unsqueeze(0)
# 注册为缓冲区(不参与训练,但会保存到模型参数中)
self.register_buffer('pe', pe)
def forward(self, x):
"""
将位置编码添加到输入嵌入中
x的形状:(batch_size, seq_len, d_model)
"""
# self.pe[:, :x.size(1), :] 自动广播到batch维度
x = x + self.pe[:, :x.size(1), :]
return x
# -------------------- 4. Transformer模型 --------------------
class TransformerModel(nn.Module):
def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=2):
"""
Transformer模型组成:
- 嵌入层:将离散token转换为连续向量
- 位置编码:添加序列位置信息
- Transformer核心层:处理序列关系
- 全连接层:输出词表概率分布
"""
super().__init__()
# 嵌入层(注意设置padding_idx=0)
self.embedding = nn.Embedding(
vocab_size,
d_model,
padding_idx=0 # 重要!使填充位置的嵌入向量为全零
)
# 位置编码
self.pos_encoder = PositionalEncoding(d_model)
# Transformer核心
self.transformer = Transformer(
d_model=d_model,
nhead=nhead, # 注意力头数(需要能被d_model整除)
num_encoder_layers=num_layers,
num_decoder_layers=num_layers,
batch_first=True # 使用(batch, seq, feature)格式
)
# 输出层
self.fc = nn.Linear(d_model, vocab_size)
def forward(self, src, src_key_padding_mask):
"""
前向传播过程:
1. 嵌入层 → 2. 缩放 → 3. 位置编码 → 4. Transformer处理 → 5. 输出投影
注意:本例简化了decoder输入,实际任务需要区分encoder/decoder输入
"""
# 步骤1:词嵌入(将整数索引转换为向量)
# src形状:(batch_size, seq_len) → (batch_size, seq_len, d_model)
src_emb = self.embedding(src)
# 步骤2:缩放嵌入向量(根据Transformer论文的建议)
src_emb = src_emb * math.sqrt(self.embedding.embedding_dim)
# 步骤3:添加位置编码
src_emb = self.pos_encoder(src_emb)
# 步骤4:通过Transformer(这里简化为使用相同输入作为encoder/decoder输入)
# 注意:实际任务中decoder应有不同的输入(如右移的目标序列)
output = self.transformer(
src_emb, # encoder输入
src_emb, # decoder输入(本例简化处理)
src_key_padding_mask=src_key_padding_mask, # 屏蔽encoder填充位置
tgt_key_padding_mask=src_key_padding_mask # 屏蔽decoder填充位置
)
# 步骤5:投影到词表空间
# 输出形状:(batch_size, seq_len, vocab_size)
return self.fc(output)
# -------------------- 5. 训练配置 --------------------
vocab_size = 20 # 词表大小(含填充符0)
batch_size = 32 # 批大小
d_model = 128 # 模型维度
epochs = 10 # 训练轮数
# 创建数据加载器
dataset = MyDataset()
dataloader = DataLoader(
dataset,
batch_size=batch_size,
collate_fn=collate_fn, # 使用自定义的批处理函数
shuffle=True
)
# 初始化模型和优化器
model = TransformerModel(vocab_size)
# 使用交叉熵损失,忽略填充位置(ignore_index=0)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# -------------------- 6. 训练循环 --------------------
for epoch in range(epochs):
total_loss = 0
for inputs, targets in dataloader:
"""
训练步骤详解:
1. 生成padding mask
2. 前向传播
3. 计算损失(忽略填充位置)
4. 反向传播
"""
# 生成padding mask(True表示需要屏蔽的填充位置)
# 原理:找出inputs中等于0的位置(即填充位置)
mask = (inputs == 0)
# 前向传播
outputs = model(inputs, src_key_padding_mask=mask)
# 计算损失
# 将输出展平为(batch_size*seq_len, vocab_size)
logits = outputs.view(-1, vocab_size)
# 将目标展平为(batch_size*seq_len)
labels = targets.view(-1)
# 计算交叉熵损失(自动忽略labels为0的位置)
loss = criterion(logits, labels)
# 反向传播
optimizer.zero_grad() # 清空梯度
loss.backward() # 计算梯度
optimizer.step() # 更新参数
total_loss += loss.item()
# 打印每轮平均损失
avg_loss = total_loss / len(dataloader)
print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")
print("训练完成!")
"""
关键概念总结:
1. 动态填充:每个batch独立填充到当前batch的最大长度,节省内存
2. Padding Mask:
- 形状:(batch_size, seq_len)
- True表示对应位置是填充,需要被注意力机制忽略
3. 位置编码:为每个位置生成唯一编码,使模型感知序列顺序
4. 注意力屏蔽:防止模型关注填充位置和未来信息(解码器)
5. 损失忽略:计算损失时跳过填充位置,避免影响参数更新
"""