PyTorch 深度学习实战(8):Transformer 与机器翻译(基于本地中英文文本文件)

在上一篇文章中,我们探讨了循环神经网络(RNN)及其在文本分类中的应用。本文将介绍 Transformer 模型的基本原理,并使用 PyTorch 和 Hugging Face 的 transformers 库实现一个简单的机器翻译模型。我们将基于本地的中英文文本文件进行实战演练。


一、Transformer 基础

Transformer 是一种基于自注意力机制(Self-Attention)的神经网络架构,由 Vaswani 等人在 2017 年提出。它在自然语言处理(NLP)任务中表现出色,尤其是在机器翻译领域。

1. Transformer 的结构

Transformer 的核心组件包括:

  • 编码器(Encoder):将输入序列(如中文句子)转换为一系列隐藏表示。

  • 解码器(Decoder):基于编码器的输出和已生成的目标序列(如英文句子)生成下一个词。

  • 自注意力机制(Self-Attention):捕捉序列中每个词与其他词的关系。

  • 位置编码(Positional Encoding):为模型提供序列中词的位置信息。

2. Transformer 的优势
  • 并行计算:与 RNN 不同,Transformer 可以并行处理整个序列,训练速度更快。

  • 长距离依赖:自注意力机制能够更好地捕捉长距离依赖关系。

  • 可扩展性:Transformer 可以通过堆叠更多的层来提升性能。

3. 机器翻译任务

机器翻译是将一种语言的文本自动翻译为另一种语言的任务。例如,将中文翻译为英文。


二、机器翻译实战

我们将使用本地的中英文文本文件来训练一个 Transformer 模型,实现中英机器翻译任务。

1. 问题描述

我们有两个文本文件

  • chinese.zh:包含中文句子。

  • english.en:包含对应的英文句子。

我们的目标是构建一个 Transformer 模型,能够将中文句子翻译为英文句子。

2. 实现步骤
  1. 加载和预处理数据。

  2. 构建词汇表并将文本转换为索引序列。

  3. 定义 Transformer 模型。

  4. 定义损失函数和优化器。

  5. 训练模型。

  6. 测试模型并评估性能。

3. 代码实现
import re
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import EncoderDecoderModel, BertConfig, AdamW, BertModel, BertLMHeadModel
​
# 数据预处理和词汇表构建
def build_vocab(lines, is_chinese=True, min_freq=1):
    vocab = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
    word_counts = {}
    for line in lines:
        line = line.strip()
        tokens = list(line) if is_chinese else re.findall(r"[\w']+|[^a-zA-Z\s]", line.lower())
        for token in tokens:
            word_counts[token] = word_counts.get(token, 0) + 1
​
    idx = 4
    for word, count in word_counts.items():
        if count >= min_freq and word not in vocab:
            vocab[word] = idx
            idx += 1
    return vocab
​
# 数据集类
class TranslationDataset(Dataset):
    def __init__(self, src_lines, tgt_lines, src_vocab, tgt_vocab, max_len=50):
        self.src_lines = src_lines
        self.tgt_lines = tgt_lines
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.max_len = max_len
​
    def __len__(self):
        return len(self.src_lines)
​
    def __getitem__(self, idx):
        src_line = self.src_lines[idx].strip()
        tgt_line = self.tgt_lines[idx].strip()
​
        # 处理源语言(中文)
        src_tokens = list(src_line)[:self.max_len]
        src_ids = [self.src_vocab.get(t, 3) for t in src_tokens]
        src_ids = [1] + src_ids + [2]  # 添加<sos>和<eos>
​
        # 处理目标语言(英文)
        tgt_tokens = re.findall(r"[\w']+|[^a-zA-Z\s]", tgt_line.lower())[:self.max_len]
        tgt_ids = [self.tgt_vocab.get(t, 3) for t in tgt_tokens]
        tgt_ids = [1] + tgt_ids + [2]
​
        return {
            'input_ids': torch.tensor(src_ids, dtype=torch.long),
            'labels': torch.tensor(tgt_ids, dtype=torch.long)
        }
​
# 数据整理函数
def collate_fn(batch):
    input_ids = [item['input_ids'] for item in batch]
    labels = [item['labels'] for item in batch]
​
    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=0)
​
    return {
        'input_ids': input_ids,
        'labels': labels,
        'attention_mask': (input_ids != 0).long()
    }
​
# 初始化Transformer模型
def build_model(src_vocab_size, tgt_vocab_size, device):
    # 初始化编码器配置
    encoder_config = BertConfig(
        vocab_size=src_vocab_size,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=8,
        intermediate_size=512,
        hidden_dropout_prob=0.3,
        attention_probs_dropout_prob=0.3
    )
​
    # 初始化解码器配置
    decoder_config = BertConfig(
        vocab_size=tgt_vocab_size,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=8,
        intermediate_size=512,
        hidden_dropout_prob=0.3,
        attention_probs_dropout_prob=0.3,
        is_decoder=True,
        add_cross_attention=True
    )
​
    # 创建实际模型实例
    encoder = BertModel(encoder_config)
    decoder = BertLMHeadModel(decoder_config)
​
    # 组装Encoder-Decoder结构
    model = EncoderDecoderModel(encoder=encoder, decoder=decoder)
​
    # 关键修复:设置解码器起始标记
    model.config.decoder_start_token_id = 1  # 对应<sos>的ID
    model.config.pad_token_id = 0  # 确保pad token设置正确
    model.config.eos_token_id = 2  # 确保eos token设置正确
​
    # 初始化权重并转移设备
    model.init_weights()
    return model.to(device)
​
# 训练函数
def train(model, dataloader, optimizer, device, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in dataloader:
            inputs = {
                'input_ids': batch['input_ids'].to(device),
                'attention_mask': batch['attention_mask'].to(device),
                'labels': batch['labels'].to(device)
            }
​
            optimizer.zero_grad()
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
​
            total_loss += loss.item()
​
        print(f"Epoch {epoch + 1} | Loss: {total_loss / len(dataloader):.4f}")
​
# 翻译函数
def translate(model, sentence, src_vocab, tgt_vocab_inv, device, max_length=50):
    model.eval()
    tokens = list(sentence.strip())
    src_ids = [src_vocab.get(t, 3) for t in tokens]
    src_ids = [1] + src_ids + [2]
    input_tensor = torch.tensor(src_ids).unsqueeze(0).to(device)  # Add batch dimension
​
    generated = model.generate(
        input_ids=input_tensor,
        max_length=max_length,
        decoder_start_token_id=1,  # <sos>
        eos_token_id=2,  # <eos>
        pad_token_id=0,  # <pad>
        num_beams=5,
        early_stopping=True
    )
​
    output_ids = generated[0].cpu().numpy()
    translated = [tgt_vocab_inv.get(idx, '<unk>') for idx in output_ids if idx not in [0, 1, 2]]
    return ''.join(translated) if len(src_vocab) > 5000 else ' '.join(translated)
​
# 主程序
if __name__ == "__main__":
    # 配置参数
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    BATCH_SIZE = 32
    EPOCHS = 20
​
    # 加载数据
    with open('chinese.zh', 'r', encoding='utf-8') as f:
        ch_lines = f.readlines()
    with open('english.en', 'r', encoding='utf-8') as f:
        en_lines = f.readlines()
​
    # 构建词汇表
    src_vocab = build_vocab(ch_lines, is_chinese=True)
    tgt_vocab = build_vocab(en_lines, is_chinese=False)
    tgt_vocab_inv = {v: k for k, v in tgt_vocab.items()}
​
    # 准备数据集
    dataset = TranslationDataset(ch_lines, en_lines, src_vocab, tgt_vocab)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
​
    # 初始化模型
    model = build_model(len(src_vocab), len(tgt_vocab), device)
    optimizer = AdamW(model.parameters(), lr=5e-4)
​
    # 训练模型
    train(model, dataloader, optimizer, device, epochs=EPOCHS)
​
    # 测试翻译
    test_sentences = [
        "你好,世界!",
        "今天天气真好",
        "人工智能改变世界"
    ]
​
    for sent in test_sentences:
        translation = translate(model, sent, src_vocab, tgt_vocab_inv, device)
        print(f"中文: {sent}")
        print(f"翻译: {translation}\n")

三、代码解析

  1. 数据加载与预处理

    • 使用 jieba 进行中文分词。

    • 构建词汇表并将文本转换为索引序列。

    • 使用 TranslationDataset 类封装数据,支持填充和截断。

  2. Transformer 模型

    • 使用 BertModel 作为编码器,BertLMHeadModel 作为解码器。

    • 通过 EncoderDecoderModel 组装编码器和解码器。

  3. 训练过程

    • 使用交叉熵损失函数和 AdamW 优化器。

    • 训练 20 个 epoch,并记录损失值。

  4. 测试过程

    • 使用 translate 函数对测试句子进行翻译。


四、运行结果

运行上述代码后,你将看到以下输出:

  • 训练过程中每 epoch 打印一次损失值。

  • 测试句子的翻译结果。


五、改进建议

在测试过程中,我们发现中文翻译结果并不理想。以下是改进建议:

  1. 增加数据量:更多的训练数据可以提升模型的泛化能力。

  2. 调整模型结构:增加编码器和解码器的层数或隐藏层大小。

  3. 使用预训练模型:例如,使用 Hugging Face 的 mT5mBART 模型。

  4. 调整超参数:例如,学习率、批次大小、训练轮数等。

  5. 数据增强:对训练数据进行同义词替换、随机删除等操作。


六、总结

本文介绍了 Transformer 的基本原理,并使用 PyTorch 实现了一个简单的中英机器翻译模型。通过这个例子,我们学习了如何处理中英文数据、构建 Transformer 模型以及进行训练和评估。

在下一篇文章中,我们将探讨时间序列预测与 LSTM 模型。敬请期待!


代码实例说明

  • 本文代码可以直接在 Jupyter Notebook 或 Python 脚本中运行。

  • 如果你有 GPU,可以将模型和数据移动到 GPU 上运行,例如:model = model.to('cuda')input_tensor = input_tensor.to('cuda')

希望这篇文章能帮助你更好地理解 Transformer 及其在机器翻译中的应用!如果有任何问题,欢迎在评论区留言讨论。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值