Transformer from Scratch Series (06): Wrapping the Full Transformer for English-German Translation

This is the sixth post in the "Transformer from Scratch" series. We integrate the modules implemented in the previous posts into a complete Transformer model and apply it to an English ➡️ German translation task. The post covers the full model wrapper, the training and inference code, translation examples, and an analysis of the training results.

🧱 Part 1: Wrapping the Full Transformer

✅ Transformer Model Structure

We assemble the model below on top of the Encoder-Decoder architecture:

import torch
import torch.nn as nn
import sys
import os

# add the project root to sys.path so that models/ and util/ can be imported
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')))

from models.model.Encoder import Encoder
from models.model.Decoder import Decoder

class Transformer(nn.Module):
    """
    Transformer模型的封装:
        1. Encoder和Decoder的组合
        2. 采用残差连接和LayerNorm
        3. 前馈网络由两个线性变换和一个ReLU激活函数组成
        4. 多头自注意力机制由多个头组成,每个头都有自己的线性变换和缩放点积注意力计算
        5. 多头自注意力机制的输出经过线性变换和dropout
        6. 需要把tokenizer部分的token处理也要考虑进入的!!!!!
    """
    def __init__(self, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_size, max_len, d_model, n_heads,
                 ffn_hidden, n_layers, dropout, device):
        super(Transformer, self).__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.trg_sos_idx = trg_sos_idx
        self.device = device
        self.encoder = Encoder(enc_voc_size=enc_voc_size,
                               max_len=max_len,
                               d_model=d_model,
                               ffn_hidden=ffn_hidden,
                               n_heads=n_heads,
                               n_layers=n_layers,
                               drop_prob=dropout,
                               device=device)
        
        self.decoder = Decoder(dec_voc_size=dec_voc_size,
                               max_len=max_len,
                               d_model=d_model,
                               ffn_hidden=ffn_hidden,
                               n_heads=n_heads,
                               n_layers=n_layers,
                               dropout=dropout,
                               device=device)
    
    def forward(self, src, trg):
        # src: [batch_size, src_len]
        # trg: [batch_size, trg_len]
        # 1. build the attention masks
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)

        # cross-attention mask: target queries attending to source keys
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)

        # decoder self-attention mask: padding mask combined with the causal (no-peek) mask
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) * \
                   self.make_no_peak_mask(trg, trg)

        # 2. encode the source sequence
        enc_src = self.encoder(src, src_mask)

        # 3. decode while attending to the encoder output
        output = self.decoder(trg, enc_src, trg_mask, src_trg_mask)

        return output
        
    def make_pad_mask(self, q, k, q_pad_idx, k_pad_idx):
        len_q, len_k = q.size(1), k.size(1)

        # batch_size x 1 x 1 x len_k
        k = k.ne(k_pad_idx).unsqueeze(1).unsqueeze(2)
        # batch_size x 1 x len_q x len_k
        k = k.repeat(1, 1, len_q, 1)

        # batch_size x 1 x len_q x 1
        q = q.ne(q_pad_idx).unsqueeze(1).unsqueeze(3)
        # batch_size x 1 x len_q x len_k
        q = q.repeat(1, 1, 1, len_k)

        mask = k & q
        return mask

    def make_no_peak_mask(self, q, k):
        len_q, len_k = q.size(1), k.size(1)

        # len_q x len_k
        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.BoolTensor).to(self.device)

        return mask
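
To make the mask shapes concrete, here is a small self-contained check (a sketch with made-up token ids and pad index 1) that re-runs the logic of make_pad_mask and make_no_peak_mask outside the class:

import torch

pad_idx = 1
src = torch.tensor([[5, 7, 9, 1, 1]])   # batch_size=1, src_len=5, last two positions are padding
trg = torch.tensor([[2, 4, 6, 1]])      # batch_size=1, trg_len=4, last position is padding

# padding mask, as in make_pad_mask: [batch_size, 1, len_q, len_k]
k = src.ne(pad_idx).unsqueeze(1).unsqueeze(2).repeat(1, 1, src.size(1), 1)
q = src.ne(pad_idx).unsqueeze(1).unsqueeze(3).repeat(1, 1, 1, src.size(1))
src_mask = k & q
print(src_mask.shape)    # torch.Size([1, 1, 5, 5])

# causal (no-peek) mask, as in make_no_peak_mask: lower-triangular [len_q, len_k]
no_peak = torch.tril(torch.ones(trg.size(1), trg.size(1))).bool()
print(no_peak)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])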
    

🛠️ Part 2: Training Script train.py

We use torchtext to load the English-German dataset (Multi30k) and write the training loop; a minimal sketch of the data module follows right below.

🔁 Core Training Loop
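
Since data.py itself does not appear in this post, here is a minimal sketch of what it might expose, assuming the legacy torchtext API (Field / BucketIterator / Multi30k, available up to torchtext 0.11, under torchtext.legacy in the last releases that ship it); all hyperparameter values are illustrative, not the author's:

# data.py (sketch, not the author's original file)
import torch
from types import SimpleNamespace
from torchtext.legacy.data import Field, BucketIterator
from torchtext.legacy.datasets import Multi30k

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# whitespace tokenization (the Field default) keeps this consistent with encode_sentence() in Part 3;
# spaCy tokenizers can be plugged in via tokenize='spacy'
source = Field(init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)
target = Field(init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)

train_data, valid_data, test_data = Multi30k.splits(exts=('.en', '.de'), fields=(source, target))
source.build_vocab(train_data, min_freq=2)
target.build_vocab(train_data, min_freq=2)

batch_size = 128
train_iter, valid_iter, test_iter = BucketIterator.splits(
    (train_data, valid_data, test_data), batch_size=batch_size, device=device)

# indices and sizes consumed by train.py and inference.py
src_pad_idx = source.vocab.stoi['<pad>']
trg_pad_idx = target.vocab.stoi['<pad>']
trg_sos_idx = target.vocab.stoi['<sos>']
enc_voc_size = len(source.vocab)
dec_voc_size = len(target.vocab)

# so that loader.source.vocab / loader.target.vocab work as used in evaluate() and inference.py
loader = SimpleNamespace(source=source, target=target)

# hyperparameters re-exported via `from data import *` (illustrative values only)
max_len, d_model, n_heads, ffn_hidden, n_layers, drop_prob = 256, 512, 8, 2048, 6, 0.1
init_lr, weight_decay, adam_eps, factor, patience = 1e-5, 5e-4, 5e-9, 0.9, 10
clip, warmup, epoch = 1.0, 100, 1000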

"""
Transformers for Translation Task


"""
import math
import time
import torch

from torch import nn, optim

from data import *
from models.model.Transformer import Transformer
from util.bleu import idx_to_word, get_bleu
from util.epoch_time import epoch_time

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def initialize_weights(m):
    # Xavier-initialize weight matrices and zero all biases
    if hasattr(m, 'weight') and m.weight is not None and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)
    if hasattr(m, 'bias') and m.bias is not None:
        nn.init.constant_(m.bias.data, 0)


model = Transformer(src_pad_idx=src_pad_idx,
                    trg_pad_idx=trg_pad_idx,
                    trg_sos_idx=trg_sos_idx,
                    enc_voc_size=enc_voc_size,
                    dec_voc_size=dec_voc_size,
                    max_len=max_len,
                    d_model=d_model,
                    n_heads=n_heads,
                    ffn_hidden=ffn_hidden,
                    n_layers=n_layers,
                    dropout=drop_prob,
                    device=device).to(device)

print(f'The model has {count_parameters(model):,} trainable parameters')
model.apply(initialize_weights)

optimizer = optim.Adam(model.parameters(), 
                       lr=init_lr, 
                       weight_decay=weight_decay,
                       eps=adam_eps)

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                                                 verbose=True,
                                                 factor=factor,
                                                 patience=patience)

criterion = nn.CrossEntropyLoss(ignore_index=trg_pad_idx).to(device)

def train(model, iterator, optimizer, criterion, clip):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):
        src = batch.src
        trg = batch.trg

        optimizer.zero_grad()
        output = model(src, trg[:, :-1])                       # decoder input: target without the last token (teacher forcing)
        output_reshape = output.contiguous().view(-1, output.shape[-1])
        trg = trg[:, 1:].contiguous().view(-1)                 # loss target: target shifted left by one position

        loss = criterion(output_reshape, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()
        print('step :', round((i / len(iterator)) * 100, 2), '% , loss :', loss.item())

    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    batch_bleu = []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            src = batch.src
            trg = batch.trg
            output = model(src, trg[:, :-1])
            output_reshape = output.contiguous().view(-1, output.shape[-1])
            trg = trg[:, 1:].contiguous().view(-1)

            loss = criterion(output_reshape, trg)
            epoch_loss += loss.item()

            total_bleu = []
            for j in range(batch.trg.size(0)):  # iterate over the actual batch size (the last batch may be smaller)
                try:
                    trg_words = idx_to_word(batch.trg[j], loader.target.vocab)
                    output_words = output[j].max(dim=1)[1]
                    output_words = idx_to_word(output_words, loader.target.vocab)
                    bleu = get_bleu(hypotheses=output_words.split(), reference=trg_words.split())
                    total_bleu.append(bleu)
                except Exception:
                    pass

            if total_bleu:
                batch_bleu.append(sum(total_bleu) / len(total_bleu))

    batch_bleu = sum(batch_bleu) / len(batch_bleu) if batch_bleu else 0.0
    return epoch_loss / len(iterator), batch_bleu

def run(total_epoch, best_loss):
    train_losses, test_losses, bleus = [], [], []
    for step in range(total_epoch):
        start_time = time.time()
        train_loss = train(model, train_iter, optimizer, criterion, clip)
        valid_loss, bleu = evaluate(model, valid_iter, criterion)
        end_time = time.time()

        if step > warmup:
            scheduler.step(valid_loss)

        train_losses.append(train_loss)
        test_losses.append(valid_loss)
        bleus.append(bleu)
        epoch_mins, epoch_secs = epoch_time(start_time, end_time)

        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), 'saved/model-{0}.pt'.format(valid_loss))

        with open('result/train_loss.txt', 'w') as f:
            f.write(str(train_losses))

        with open('result/bleu.txt', 'w') as f:
            f.write(str(bleus))

        with open('result/test_loss.txt', 'w') as f:
            f.write(str(test_losses))

        print(f'Epoch: {step + 1} | Time: {epoch_mins}m {epoch_secs}s')
        print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
        print(f'\tVal Loss: {valid_loss:.3f} |  Val PPL: {math.exp(valid_loss):7.3f}')
        print(f'\tBLEU Score: {bleu:.3f}')
    
if __name__ == '__main__':
    run(total_epoch=epoch, best_loss=float('inf'))
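
One detail in train() worth spelling out: the decoder is fed trg[:, :-1] while the loss is computed against trg[:, 1:], so at every position the model is trained to predict the next token (teacher forcing). A toy illustration with made-up token ids:

import torch

trg = torch.tensor([[2, 11, 12, 13, 3]])   # <sos> w1 w2 w3 <eos>
decoder_input = trg[:, :-1]                # tensor([[ 2, 11, 12, 13]]) -> fed to the decoder
loss_target = trg[:, 1:]                   # tensor([[11, 12, 13,  3]]) -> compared with the output logits
print(decoder_input)
print(loss_target)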

🔍 Part 3: Inference and Translation inference.py

We write a simple greedy-decoding function to generate translations.

import torch
from models.model.Transformer import Transformer
from util.bleu import idx_to_word
from data import loader, src_pad_idx, trg_pad_idx, trg_sos_idx, enc_voc_size, dec_voc_size
from config import max_len, d_model, n_heads, ffn_hidden, n_layers, drop_prob  # import the hyperparameters here if they are not already in scope

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# build the Transformer model with the same hyperparameters used for training
model = Transformer(
    src_pad_idx=src_pad_idx,
    trg_pad_idx=trg_pad_idx,
    trg_sos_idx=trg_sos_idx,
    enc_voc_size=enc_voc_size,
    dec_voc_size=dec_voc_size,
    max_len=max_len,
    d_model=d_model,
    n_heads=n_heads,
    ffn_hidden=ffn_hidden,
    n_layers=n_layers,
    dropout=drop_prob,
    device=device
).to(device)

# load the trained weights
model.load_state_dict(torch.load("saved/model-6.488133251667023.pt", map_location=device))  # replace with the path to your best checkpoint
model.eval()

def encode_sentence(sentence: str, vocab: dict, pad_idx: int, max_len: int):
    tokens = sentence.lower().strip().split()
    idxs = [vocab.get(token, vocab['<unk>']) for token in tokens]
    if len(idxs) < max_len:
        idxs += [pad_idx] * (max_len - len(idxs))
    else:
        idxs = idxs[:max_len]
    return torch.tensor(idxs).unsqueeze(0).to(device)  # shape: [1, max_len]

def make_trg_mask(trg, pad_idx=0):
    # shape: [batch_size, 1, trg_len, trg_len]
    trg_pad_mask = (trg != pad_idx).unsqueeze(1).unsqueeze(2)
    trg_len = trg.shape[1]
    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
    trg_mask = trg_pad_mask & trg_sub_mask
    return trg_mask

def translate(sentence: str):
    src_tensor = encode_sentence(sentence, loader.source.vocab, src_pad_idx, max_len)
    src_mask = (src_tensor != src_pad_idx).unsqueeze(1).unsqueeze(2)
    enc_src = model.encoder(src_tensor, src_mask)

    trg_indexes = [trg_sos_idx]

    for _ in range(max_len):
        trg_tensor = torch.tensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = make_trg_mask(trg_tensor, pad_idx=trg_pad_idx)
        # argument order matches Transformer.forward: (trg, enc_src, trg_mask, src_trg_mask)
        output = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        pred_token = output[:, -1, :].argmax(-1).item()
        trg_indexes.append(pred_token)

        # stop once <eos> (or padding) is generated
        if pred_token == trg_pad_idx or pred_token == loader.target.vocab.get('<eos>', trg_pad_idx):
            break

    translated_tokens = idx_to_word(trg_indexes, loader.target.vocab)
    return translated_tokens

if __name__ == "__main__":
    print(">>> Transformer Translator Ready. Type 'quit' to exit.")
    while True:
        sentence = input(">> Enter English sentence: ").strip()
        if sentence.lower() in ['quit', 'exit']:
            break
        translation = translate(sentence)
        print("Translation:", translation)

📊 Part 4: Training and Translation Results

Dataset: Multi30k (English-German sentence pairs)
Model: 6-layer Transformer
Training time: about 30 minutes on a single GPU
BLEU score: roughly 25-27 (without hyperparameter tuning)

Example output:

  • Input (EN): A man is playing guitar on the street.

  • Output (DE): Ein Mann spielt Gitarre auf der Straße .

🧠 Part 5: Hands-On Summary

With this post we have gone all the way from implementing each module from scratch to putting the model to work:

Completed modules:

  • Embedding module
  • Positional Encoding
  • Encoder
  • Decoder
  • Multi-head Attention
  • Full Transformer wrapper
  • Translation task in practice

🚀 We have not only implemented every component from the paper, but also applied it to a real NLP translation task, which is exactly the path where research and engineering meet.

🗂️ Suggested Repository Structure
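
Based on the import paths used throughout this post, the repository can be laid out roughly as follows (treat this tree as an inferred suggestion; the embedding, attention and layer sub-modules from the earlier posts also live under models/):

transformer/
├── config.py               # hyperparameters (max_len, d_model, n_heads, ...)
├── data.py                 # dataset loading, vocabularies, batch iterators
├── train.py                # training script (Part 2)
├── inference.py            # greedy-decoding translator (Part 3)
├── models/
│   └── model/
│       ├── Encoder.py
│       ├── Decoder.py
│       └── Transformer.py  # the wrapper from Part 1
├── util/
│   ├── bleu.py             # idx_to_word, get_bleu
│   └── epoch_time.py
├── saved/                  # model checkpoints
└── result/                 # train_loss.txt, bleu.txt, test_loss.txt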

📚 What's Next?

  • 🎯 Switch to a larger dataset (e.g. IWSLT, WMT)

  • 🔁 Add beam search decoding (a sketch follows this list)

  • 📈 Add TensorBoard to monitor the training process

  • 💡 Try fine-tuning BERT as the encoder
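
For the beam search item above, here is a minimal sketch that reuses the names from inference.py (model, loader, encode_sentence, make_trg_mask, idx_to_word and the special-token indices). It keeps the beam_size best hypotheses by cumulative log-probability and does not apply length normalization:

def translate_beam(sentence: str, beam_size: int = 3):
    src_tensor = encode_sentence(sentence, loader.source.vocab, src_pad_idx, max_len)
    src_mask = (src_tensor != src_pad_idx).unsqueeze(1).unsqueeze(2)
    enc_src = model.encoder(src_tensor, src_mask)
    eos_idx = loader.target.vocab.get('<eos>', trg_pad_idx)

    # each hypothesis is a (token id list, cumulative log-probability) pair
    beams = [([trg_sos_idx], 0.0)]
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens[-1] == eos_idx:              # finished hypotheses are carried over unchanged
                candidates.append((tokens, score))
                continue
            trg_tensor = torch.tensor(tokens).unsqueeze(0).to(device)
            trg_mask = make_trg_mask(trg_tensor, pad_idx=trg_pad_idx)
            output = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
            log_probs = torch.log_softmax(output[0, -1, :], dim=-1)
            top_lp, top_idx = log_probs.topk(beam_size)
            for lp, idx in zip(top_lp.tolist(), top_idx.tolist()):
                candidates.append((tokens + [idx], score + lp))
        # keep only the beam_size best hypotheses
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_size]
        if all(tokens[-1] == eos_idx for tokens, _ in beams):
            break

    best_tokens, _ = beams[0]
    return idx_to_word(best_tokens, loader.target.vocab)

With beam_size=1 this reduces to the greedy translate() above; larger beams trade decoding speed for (usually) slightly better BLEU.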


Are you ready to step into the world of large-model fine-tuning, architectural innovation experiments, or deployment and inference optimization?

Then congratulations! You have hand-built a Transformer from scratch!

