Transformer模型中的两种掩码

模型训练通常使用 批处理batch来提升训练效率。而实际中Transformer的输入序列(如句子、文本片段)长度往往不一致。为了让这些样本可以组成一个统一的形状 [B, T] 的张量给GPU进行并行计算提高效率,需要将较短的序列填充(pad)到最大长度,比如统一为 T=10。填充内容一般是特殊的 <PAD> 符号,或者对应的 token ID(如 0)。

Transformer在训练和推理时,padding的作用不同。训练阶段,为并行处理多个长度不一的序列,需用padding统一长度,确保模型能一次性处理整个批次。推理阶段,单序列生成时无需padding,模型逐token输出;但若采用批处理推理,仍需padding对齐不同长度的序列。因此,训练时padding必不可少,推理时视处理方式而定。

整个 Transformer 结构中涉及到的 “掩码” 类型一共有两种:① 用于区分同一个 batch 中不同长度序列是否被填充的 key padding mask;② 在训练时,Decoder 中用于模仿推理过程中在编码当前时禁止看到未来信息的 attention mask(也叫做 casual mask 或 future mask)

对于自回归模型来说,我们在训练阶段为了加速模型的训练过程,都是直接使用小批量样本来进行训练,而这就导致了不同样本长度不能进行计算的问题,因而要使用填充操作来使同一个 batch 中的序列长度保持一致。同时,在自注意力机制中,为了防止注意力被分配到 padding 位置上,所以我们需要通过 padding mask 来忽略掉这部分注意力值。

名称核心作用应用模块是否与batch相关典型形状与值示例使用场景
key_padding_mask屏蔽无效的padding符号(如句末填充的占位符)Encoder层、Decoder层的自注意力与交叉注意力✅ 是(每个样本独立)[B, T]
例:[0,0,-inf]
处理变长序列时过滤padding
causal mask
(attention_mask)
防止模型看到未来时刻的信息(保持时序依赖)Decoder层的自注意力❌ 否(全局共享)[T, T]
例:上三角矩阵填充-inf
文本生成、自回归预测任务
  • ​​Query​​:代表当前生成任务的需求(如"接下来该生成什么词?"),通过当前隐藏状态线性变换生成
  • ​​Key​​:作为历史上下文的索引(如已生成的词语),用于与Query匹配相关性
  • Value​​:存储实际语义内容,根据注意力权重提取关键信息
  • 动态信息检索系统:Q提出问题 → K响应线索 → V提供语义内容
  • 注意力权重计算:通过Q与K的点积相似度 → 缩放 → softmax归一化 → 加权聚合V
  • 线性变换​​:每个输入位置的词向量通过独立权重矩阵生成Q/K/V
  • 缩放点积​​:为防止梯度消失,需除以√d_k(d_k为向量维度)
  • 多头注意力​​:分割为多个子空间(如8头),分别捕捉语法/语义等多维度关系
  • 位置编码​​:通过正弦/余弦函数注入位置信息,保持序列顺序感知
  • 自回归生成:通过因果掩码确保仅关注历史信息
  • 注意力聚焦:生成"这里"时,Q高匹配"公众号"的K,提取其V的语义特征
  • 概率预测:加权后的上下文向量经FFN输出词表概率分布
import torch
import torch.nn.functional as F

# ============== 输入设置 ==============
# 两个示例序列:
# 序列1: [1, 2, 0] (0是填充符号)
# 序列2: [3, 4, 5] (没有填充)
inputs = torch.tensor([[1, 2, 0], [3, 4, 5]])
print("输入序列:")
print(inputs)

# ============== 1. 填充掩码 ==============
print("\n== 第一部分: 填充掩码 ==")

# 填充掩码标识哪些位置包含填充符(0)
# True = 这个位置应被忽略
# False = 这个位置包含真实数据
key_padding_mask = (inputs == 0)
print("\n填充掩码 (True表示'忽略此位置'):")
print(key_padding_mask)
print("解释: 第一个序列的最后一位是填充符,所以标记为True")

# ============== 2. 因果掩码 ==============
print("\n== 第二部分: 因果掩码 ==")

# 因果掩码防止token关注未来位置
# 这对GPT等自回归模型至关重要
seq_len = inputs.size(1)

# 创建下三角矩阵(包括对角线),1表示可见,0表示被掩盖
causal_mask = torch.tril(torch.ones(seq_len, seq_len))
# 转换为布尔掩码(0->True表示需掩盖,1->False表示可见)
causal_mask = causal_mask == 0
print("\n因果掩码 (True表示'需掩盖'):")
print(causal_mask)
print("解释: 每个位置只能看到自己和之前的位置")

# ============== 3. 应用掩码到注意力分数 ==============
print("\n== 第三部分: 应用掩码 ==")

# 模拟注意力分数
attn_scores = torch.randn(2, 3, 3)
print("\n原始注意力分数 (随机):")
print(attn_scores)

# 应用填充掩码
attn_scores = attn_scores.masked_fill(
    key_padding_mask.unsqueeze(1),  # [2,3] -> [2,1,3]
    -1e9  # 非常小的值,softmax后接近0
)

# 应用因果掩码
attn_scores = attn_scores.masked_fill(
    causal_mask.unsqueeze(0),  # [3,3] -> [1,3,3]
    -1e9
)

# 计算最终注意力权重
attn_weights = F.softmax(attn_scores, dim=-1)
print("\n最终注意力权重:")
print(attn_weights)
print("解释: 权重已被两种掩码调整,填充位置和未来位置的权重接近0")

Transformer 中 Decoder 里的 tgt key padding mask 操作其实是不需要的。因为计算损失时只会考虑非padding的字符,padding字符不参与损失计算。参考:Transformer中Decoder真的不需要PaddingMask?QKV到底是怎么来的?

注意力机制和mask掩码示例代码

import torch
import torch.nn as nn

# 设置随机种子,确保实验结果可复现
torch.random.manual_seed(42)

# 定义模型维度和序列长度
d_model = 7  # 模型的维度
tgt_len = 6  # 目标序列长度
src_len = 5  # 源序列长度

# =================== 编码器部分 ===================
# 随机生成源序列的嵌入表示
src_em_input = torch.randn(src_len, d_model)
# 定义线性变换层用于计算Q、K、V
q_proj0 = nn.Linear(d_model, d_model)
k_proj0 = nn.Linear(d_model, d_model)
v_proj0 = nn.Linear(d_model, d_model)
# 计算Q、K、V
q, k, v = q_proj0(src_em_input), k_proj0(src_em_input), v_proj0(src_em_input)
# 创建源序列的填充掩码
src_key_padding_mask = torch.tensor([[False, False, False, True, True]])
# 计算注意力权重
weight = q @ k.transpose(0, 1)  # [src_len, src_len]
print("Encoder attention weight before key padding:")
print(weight)
# 应用填充掩码
weight = weight.masked_fill(src_key_padding_mask, float('-inf'))
print("Encoder attention weight after key padding:")
print(weight)
# 计算softmax值
weight = torch.softmax(weight, dim=-1)  # [src_len, src_len]
# 得到加权后的输出(记忆)
memory = weight @ v  # [src_len, d_model]

# =================== 解码器部分 ===================
## ------------------ 带遮罩的自注意力机制
# 随机生成目标序列的嵌入表示
tgt_em_input = torch.randn(tgt_len, d_model)
# 创建目标序列的填充掩码
tgt_key_padding_mask = torch.tensor([[False, False, False, True, True, True]])
# 创建上三角矩阵作为遮罩,防止未来信息泄露
attn_mask = (torch.triu(torch.ones((tgt_len, tgt_len))) == 1).transpose(0, 1)
attn_mask = attn_mask.float().masked_fill(attn_mask == 0, float('-inf')).masked_fill(attn_mask == 1, float(0.))
# 定义线性变换层用于计算Q、K、V
q_proj1 = nn.Linear(d_model, d_model)
k_proj1 = nn.Linear(d_model, d_model)
v_proj1 = nn.Linear(d_model, d_model)
# 计算Q、K、V
q, k, v = q_proj1(tgt_em_input), k_proj1(tgt_em_input), v_proj1(tgt_em_input)  # [tgt_len, d_model]
# 计算注意力权重
weight = q @ k.transpose(0, 1)  # [tgt_len, tgt_len]
weight += attn_mask
print("Decoder masked attention weight before key padding:")
print(weight)



# 这里控制是否应用填充掩码
weight = weight.masked_fill(tgt_key_padding_mask, float('-inf'))
print("Decoder masked attention weight after key padding:")
print(weight)

weight = torch.softmax(weight, dim=-1)  # [tgt_len, tgt_len]
output = weight @ v  # [tgt_len, tgt_len] @ [tgt_len, d_model]



## ------------------ 交叉注意力机制
# 定义线性变换层用于计算Q、K、V
q_proj2 = nn.Linear(d_model, d_model)
k_proj2 = nn.Linear(d_model, d_model)
v_proj2 = nn.Linear(d_model, d_model)
# 计算Q、K、V,其中K和V来自编码器的记忆
q, k, v = q_proj2(output), k_proj2(memory), v_proj2(memory)  # q: [tgt_len, d_model], k,v: [src_len, d_model]
# 计算注意力权重
weight = q @ k.transpose(0, 1)  # [tgt_len, src_len]
# 应用源序列的填充掩码
weight = weight.masked_fill(src_key_padding_mask, float('-inf'))
weight = torch.softmax(weight, dim=-1)
output = weight @ v  # [tgt_len, d_model]

# =================== 损失计算 ===================
# 定义目标标签和损失函数
tgt = torch.tensor([1, 2, 3, 0, 0, 0], dtype=torch.long)
loss_fcn = torch.nn.CrossEntropyLoss(ignore_index=0)  # 忽略填充标记(这里为0)的损失计算
loss = loss_fcn(output, tgt)
print(loss)

完整transformer代码

import torch
import torch.nn as nn
import math
from torch.nn import Transformer
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# -------------------- 1. 数据集生成 --------------------
class MyDataset(Dataset):
    def __init__(self, num_samples=1000, max_len=10, vocab_size=20):
        """
        生成随机长度的数字序列数据集
        关键点:
        - 序列长度随机(模拟真实场景的变长序列)
        - 0被保留为填充符号,因此有效token从1开始
        """
        self.data = []
        for _ in range(num_samples):
            # 生成3到max_len之间的随机长度
            seq_len = torch.randint(3, max_len, (1,)).item()
            # 生成序列内容(注意:从1开始,0留给填充)
            sequence = torch.randint(1, vocab_size, (seq_len,))
            self.data.append(sequence)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# -------------------- 2. 数据处理与填充 --------------------
def collate_fn(batch):
    """
    批处理函数,主要完成:
    1. 动态填充:将不同长度的序列填充到相同长度
    2. 生成目标序列(输入右移一位)
    3. 注意:实际任务中需要根据任务类型调整目标生成方式
    """
    # 对输入进行填充(batch_first=True表示batch维度在前)
    # padding_value=0 表示用0填充短序列
    padded_inputs = pad_sequence(
        batch, 
        batch_first=True, 
        padding_value=0
    )
    
    # 生成目标序列(将输入序列右移一位,末尾补0)
    # 例如:输入 [1,2,3] → 目标 [2,3,0]
    targets = pad_sequence(
        [torch.cat([seq[1:], torch.tensor([0])]) for seq in batch],
        batch_first=True, 
        padding_value=0
    )
    
    return padded_inputs, targets

# -------------------- 3. 位置编码 --------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        """
        位置编码实现:
        - 使用正弦和余弦函数生成位置编码
        - 公式:PE(pos,2i) = sin(pos/10000^(2i/d_model))
                 PE(pos,2i+1) = cos(pos/10000^(2i/d_model))
        """
        super().__init__()
        # 创建位置编码矩阵 (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        # 生成位置序列 (0到max_len-1)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 计算div_term用于调整正弦波频率
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        # 填充位置编码矩阵
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数位置用sin
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数位置用cos
        # 扩展维度:(1, max_len, d_model) 用于后续广播
        pe = pe.unsqueeze(0)
        # 注册为缓冲区(不参与训练,但会保存到模型参数中)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        将位置编码添加到输入嵌入中
        x的形状:(batch_size, seq_len, d_model)
        """
        # self.pe[:, :x.size(1), :] 自动广播到batch维度
        x = x + self.pe[:, :x.size(1), :]
        return x

# -------------------- 4. Transformer模型 --------------------
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=128, nhead=4, num_layers=2):
        """
        Transformer模型组成:
        - 嵌入层:将离散token转换为连续向量
        - 位置编码:添加序列位置信息
        - Transformer核心层:处理序列关系
        - 全连接层:输出词表概率分布
        """
        super().__init__()
        # 嵌入层(注意设置padding_idx=0)
        self.embedding = nn.Embedding(
            vocab_size, 
            d_model, 
            padding_idx=0  # 重要!使填充位置的嵌入向量为全零
        )
        # 位置编码
        self.pos_encoder = PositionalEncoding(d_model)
        # Transformer核心
        self.transformer = Transformer(
            d_model=d_model,
            nhead=nhead,          # 注意力头数(需要能被d_model整除)
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            batch_first=True      # 使用(batch, seq, feature)格式
        )
        # 输出层
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src, src_key_padding_mask):
        """
        前向传播过程:
        1. 嵌入层 → 2. 缩放 → 3. 位置编码 → 4. Transformer处理 → 5. 输出投影
        注意:本例简化了decoder输入,实际任务需要区分encoder/decoder输入
        """
        # 步骤1:词嵌入(将整数索引转换为向量)
        # src形状:(batch_size, seq_len) → (batch_size, seq_len, d_model)
        src_emb = self.embedding(src)
        
        # 步骤2:缩放嵌入向量(根据Transformer论文的建议)
        src_emb = src_emb * math.sqrt(self.embedding.embedding_dim)
        
        # 步骤3:添加位置编码
        src_emb = self.pos_encoder(src_emb)
        
        # 步骤4:通过Transformer(这里简化为使用相同输入作为encoder/decoder输入)
        # 注意:实际任务中decoder应有不同的输入(如右移的目标序列)
        output = self.transformer(
            src_emb,            # encoder输入
            src_emb,            # decoder输入(本例简化处理)
            src_key_padding_mask=src_key_padding_mask,  # 屏蔽encoder填充位置
            tgt_key_padding_mask=src_key_padding_mask   # 屏蔽decoder填充位置
        )
        
        # 步骤5:投影到词表空间
        # 输出形状:(batch_size, seq_len, vocab_size)
        return self.fc(output)

# -------------------- 5. 训练配置 --------------------
vocab_size = 20    # 词表大小(含填充符0)
batch_size = 32    # 批大小
d_model = 128      # 模型维度
epochs = 10        # 训练轮数

# 创建数据加载器
dataset = MyDataset()
dataloader = DataLoader(
    dataset, 
    batch_size=batch_size, 
    collate_fn=collate_fn,  # 使用自定义的批处理函数
    shuffle=True
)

# 初始化模型和优化器
model = TransformerModel(vocab_size)
# 使用交叉熵损失,忽略填充位置(ignore_index=0)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# -------------------- 6. 训练循环 --------------------
for epoch in range(epochs):
    total_loss = 0
    for inputs, targets in dataloader:
        """
        训练步骤详解:
        1. 生成padding mask
        2. 前向传播
        3. 计算损失(忽略填充位置)
        4. 反向传播
        """
        # 生成padding mask(True表示需要屏蔽的填充位置)
        # 原理:找出inputs中等于0的位置(即填充位置)
        mask = (inputs == 0)
        
        # 前向传播
        outputs = model(inputs, src_key_padding_mask=mask)
        
        # 计算损失
        # 将输出展平为(batch_size*seq_len, vocab_size)
        logits = outputs.view(-1, vocab_size)
        # 将目标展平为(batch_size*seq_len)
        labels = targets.view(-1)
        # 计算交叉熵损失(自动忽略labels为0的位置)
        loss = criterion(logits, labels)
        
        # 反向传播
        optimizer.zero_grad()  # 清空梯度
        loss.backward()        # 计算梯度
        optimizer.step()       # 更新参数
        
        total_loss += loss.item()
    
    # 打印每轮平均损失
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f}")

print("训练完成!")

"""
关键概念总结:
1. 动态填充:每个batch独立填充到当前batch的最大长度,节省内存
2. Padding Mask:
   - 形状:(batch_size, seq_len)
   - True表示对应位置是填充,需要被注意力机制忽略
3. 位置编码:为每个位置生成唯一编码,使模型感知序列顺序
4. 注意力屏蔽:防止模型关注填充位置和未来信息(解码器)
5. 损失忽略:计算损失时跳过填充位置,避免影响参数更新
"""
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值