Datawhale AI 夏令营 - 基于transformer和术语词典的机器翻译

1. 安装和导入依赖

!pip install -U pip setuptools wheel -i https://pypi.tuna.tsinghua.edu.cn/simple
!pip install -U 'spacy[cuda12x]' -i https://pypi.tuna.tsinghua.edu.cn/simple
!pip install ../dataset/en_core_web_trf-3.7.3-py3-none-any.whl

这些命令用于安装Python库,包括pip, setuptools, wheel, spacy以及一个本地的Spacy模型。

2. 数据预处理

from torchtext.data.utils import get_tokenizer
import jieba

# 定义tokenizer
en_tokenizer = get_tokenizer('spacy', language='en_core_web_trf')
zh_tokenizer = lambda x: list(jieba.cut(x))  # 使用jieba分词

这段代码定义了英文和中文的分词器,英文使用Spacy,中文使用Jieba。

from typing import List, Tuple
from torchtext.vocab import build_vocab_from_iterator

# 读取数据函数
def read_data(file_path: str) -> List[str]:
    with open(file_path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f]

# 数据预处理函数
def preprocess_data(en_data: List[str], zh_data: List[str]) -> List[Tuple[List[str], List[str]]]:
    processed_data = []
    for en, zh in zip(en_data, zh_data):
        en_tokens = en_tokenizer(en.lower())[:MAX_LENGTH]
        zh_tokens = zh_tokenizer(zh)[:MAX_LENGTH]
        if en_tokens and zh_tokens:  # 确保两个序列都不为空
            processed_data.append((en_tokens, zh_tokens))
    return processed_data

# 构建词汇表
def build_vocab(data: List[Tuple[List[str], List[str]]]):
    en_vocab = build_vocab_from_iterator(
        (en for en, _ in data),
        specials=['<unk>', '<pad>', '<bos>', '<eos>']
    )
    zh_vocab = build_vocab_from_iterator(
        (zh for _, zh in data),
        specials=['<unk>', '<pad>', '<bos>', '<eos>']
    )
    en_vocab.set_default_index(en_vocab['<unk>'])
    zh_vocab.set_default_index(zh_vocab['<unk>'])
    return en_vocab, zh_vocab
  • read_data函数用于从文件中读取数据。
  • preprocess_data函数进行数据预处理,包括分词和截断。
  • build_vocab函数用于构建词汇表。

ps: 由于原训练集存在标点符号等问题,考虑添加数据清洗模块

preprocess_data函数作用

preprocess_data 函数用于对原始的英文和中文数据进行预处理,包括分词和截断操作,将处理后的数据存储为一对对的序列。具体步骤如下:

  1. 定义一个空列表存储处理后的数据:
    processed_data = []

  2. 遍历英文和中文数据对:
    for en, zh in zip(en_data, zh_data):

    • 使用 zip 将英文数据和中文数据对齐,以便同时处理对应的英文句子和中文句子。
  3. 对英文和中文句子进行分词和截断:
    en_tokens = en_tokenizer(en.lower())[:MAX_LENGTH] zh_tokens = zh_tokenizer(zh)[:MAX_LENGTH]

    • 将英文句子转换为小写,并使用 en_tokenizer 进行分词,然后截断到最大长度 MAX_LENGTH
    • 使用 zh_tokenizer 对中文句子进行分词,然后截断到最大长度 MAX_LENGTH
  4. 确保两个序列都不为空:
    if en_tokens and zh_tokens:

    • 仅在英文和中文序列都不为空时,才将其添加到 processed_data 列表中。
  5. 返回处理后的数据:
    return processed_data

build_vocab函数作用

build_vocab 函数用于构建英文和中文的词汇表,包括特殊标记,并设置默认索引。具体步骤如下:

  1. 构建英文词汇表:
    en_vocab = build_vocab_from_iterator( (en for en, _ in data), specials=['<unk>', '<pad>', '<bos>', '<eos>'] )

    • 使用 build_vocab_from_iterator 从英文数据中构建词汇表。
    • specials 参数用于指定特殊标记,包括 <unk> (未知词),<pad> (填充),<bos> (序列开始),<eos> (序列结束)。
  2. 构建中文词汇表:
    zh_vocab = build_vocab_from_iterator( (zh for _, zh in data), specials=['<unk>', '<pad>', '<bos>', '<eos>'] )

    • 使用 build_vocab_from_iterator 从中文数据中构建词汇表,特殊标记与英文词汇表相同。
  3. 设置默认索引:
    en_vocab.set_default_index(en_vocab['<unk>']) zh_vocab.set_default_index(zh_vocab['<unk>'])

    • 设置英文和中文词汇表的默认索引为 <unk> 的索引,用于处理词汇表中不存在的词。
  4. 返回词汇表:
    return en_vocab, zh_vocab

3. 构建数据加载器

from torch.utils.data import Dataset, DataLoader

# 自定义数据集
class TranslationDataset(Dataset):
    def __init__(self, data: List[Tuple[List[str], List[str]]]):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# 数据加载器
def create_dataloader(data: List[Tuple[List[str], List[str]]], batch_size: int):
    dataset = TranslationDataset(data)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# 数据批处理函数
def collate_fn(batch):
    en_batch, zh_batch = zip(*batch)
    en_lens = [len(seq) for seq in en_batch]
    zh_lens = [len(seq) for seq in zh_batch]
    en_padded = pad_sequence([torch.tensor(seq) for seq in en_batch], padding_value=en_vocab['<pad>'])
    zh_padded = pad_sequence([torch.tensor(seq) for seq in zh_batch], padding_value=zh_vocab['<pad>'])
    return en_padded, zh_padded, en_lens, zh_lens

collate_fn 是一个数据批处理函数,用于将一个批次的数据进行处理和整理,主要功能是对不同长度的序列进行填充,使其具有相同的长度,方便后续模型的处理。具体步骤如下:

  1. 解压批次数据:
    en_batch, zh_batch = zip(*batch)
    • 将批次数据解压成两个独立的批次,一个包含英文序列 (en_batch),另一个包含中文序列 (zh_batch)。
  2. 计算每个序列的长度:
    en_lens = [len(seq) for seq in en_batch] zh_lens = [len(seq) for seq in zh_batch]
    • 分别计算每个英文序列和中文序列的长度,并存储在 en_lenszh_lens 列表中。
  3. 将序列转换为张量并填充:
    en_padded = pad_sequence([torch.tensor(seq) for seq in en_batch], padding_value=en_vocab['<pad>']) zh_padded = pad_sequence([torch.tensor(seq) for seq in zh_batch], padding_value=zh_vocab['<pad>'])
    • 使用 pad_sequence 函数对英文序列和中文序列进行填充,使其具有相同的长度。
    • 填充值分别为英文词汇表和中文词汇表中的 <pad> 标记。
    • torch.tensor(seq) 将每个序列转换为张量。
  4. 返回处理后的数据:
    return en_padded, zh_padded, en_lens, zh_lens
    • 返回填充后的英文序列和中文序列,以及每个序列的原始长度。
输入数据
# 数据加载函数

def load_data(train_path: str, dev_en_path: str, dev_zh_path: str, test_en_path: str):
    # 读取训练数据
    train_data = read_data(train_path)
    train_en, train_zh = zip(*(line.split('\t') for line in train_data))
    # 读取开发集和测试集
    dev_en = read_data(dev_en_path)
    dev_zh = read_data(dev_zh_path)
    test_en = read_data(test_en_path)
    # 预处理数据
    train_processed = preprocess_data(train_en, train_zh)
    dev_processed = preprocess_data(dev_en, dev_zh)
    test_processed = [(en_tokenizer(en.lower())[:MAX_LENGTH], []) for en in test_en if en.strip()]
    # 构建词汇表
    global en_vocab, zh_vocab
    en_vocab, zh_vocab = build_vocab(train_processed)

    # 创建数据集
    train_dataset = TranslationDataset(train_processed, en_vocab, zh_vocab)
    dev_dataset = TranslationDataset(dev_processed, en_vocab, zh_vocab)
    test_dataset = TranslationDataset(test_processed, en_vocab, zh_vocab)
    from torch.utils.data import Subset
    # 假设你有10000个样本,你只想用前1000个样本进行测试
    indices = list(range(N))
    train_dataset = Subset(train_dataset, indices)
    # 创建数据加载器
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, drop_last=True)

    dev_loader = DataLoader(dev_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, drop_last=True)

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, drop_last=True)
    return train_loader, dev_loader, test_loader, en_vocab, zh_vocab

  • 数据读取

    • 使用Python内置的文件读取方法读取文本数据。
    • zip函数将两个列表打包成元组列表。
  • 数据预处理

    • 分词:使用Spacy和Jieba对英文和中文句子进行分词。
    • 截断:将句子截断为指定的最大长度。
    • 数据清洗:去除空行和不符合要求的句子。
  • 构建词汇表

    • 使用TorchText的build_vocab_from_iterator函数构建词汇表。
    • 特殊符号处理:添加<unk>(未知词)、<pad>(填充)、<bos>(句子开始)和<eos>(句子结束)等特殊符号,并设置默认索引。
  • 自定义数据集和数据加载器

    • 创建自定义的TranslationDataset类,用于存储和访问数据。
    • 使用DataLoader创建数据加载器,方便批量处理数据。
    • collate_fn用于将不同长度的句子处理成统一长度。
  • 子集选择

    • 使用Subset选择数据集的一部分进行测试,便于快速验证模型。

4. 模型定义与训练

PositionalEncoding 类
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # 创建一个形状为 (max_len, d_model) 的零矩阵
        pe = torch.zeros(max_len, d_model)

        # 生成一个形状为 (max_len, 1) 的位置索引矩阵
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # 计算位置编码的分母项
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        # 对偶数维度进行正弦变换
        pe[:, 0::2] = torch.sin(position * div_term)

        # 对奇数维度进行余弦变换
        pe[:, 1::2] = torch.cos(position * div_term)

        # 添加批次维度并转置以符合后续操作
        pe = pe.unsqueeze(0).transpose(0, 1)

        # 注册为持久缓冲区,不作为模型参数更新
        self.register_buffer('pe', pe)

    def forward(self, x):
        # 将位置编码加到输入上
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
TransformerModel 类
class TransformerModel(nn.Module):
    def __init__(self, src_vocab, tgt_vocab, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout):
        super(TransformerModel, self).__init__()
        # 初始化 Transformer 模型
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)

        # 初始化源语言和目标语言的嵌入层
        self.src_embedding = nn.Embedding(len(src_vocab), d_model)
        self.tgt_embedding = nn.Embedding(len(tgt_vocab), d_model)

        # 初始化位置编码
        self.positional_encoding = PositionalEncoding(d_model, dropout)

        # 定义输出的全连接层
        self.fc_out = nn.Linear(d_model, len(tgt_vocab))

        # 保存词汇表和嵌入维度
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.d_model = d_model

    def forward(self, src, tgt):
        # 调整 src 和 tgt 的维度,变为 (seq_len, batch_size)
        src = src.transpose(0, 1)
        tgt = tgt.transpose(0, 1)

        # 生成源和目标序列的掩码
        src_mask = self.transformer.generate_square_subsequent_mask(src.size(0)).to(src.device)
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)

        # 生成源和目标序列的填充掩码
        src_padding_mask = (src == self.src_vocab['<pad>']).transpose(0, 1)
        tgt_padding_mask = (tgt == self.tgt_vocab['<pad>']).transpose(0, 1)

        # 对源和目标序列进行嵌入和位置编码
        src_embedded = self.positional_encoding(self.src_embedding(src) * math.sqrt(self.d_model))
        tgt_embedded = self.positional_encoding(self.tgt_embedding(tgt) * math.sqrt(self.d_model))

        # 通过 Transformer 模型进行前向传播
        output = self.transformer(src_embedded, tgt_embedded,
                                  src_mask, tgt_mask, None, src_padding_mask, tgt_padding_mask, src_padding_mask)

        # 通过全连接层生成最终输出,并调整维度
        return self.fc_out(output).transpose(0, 1)
  1. PositionalEncoding 类:

    • 位置编码用于将位置信息注入到输入张量中,以便模型在处理序列数据时能够区分不同位置的元素。
    • 使用正弦和余弦函数生成位置编码矩阵,并将其注册为模型的持久缓冲区。
    • 在前向传播时,将位置编码加到输入张量上,并应用 dropout。
  2. TransformerModel 类:

    • 定义了一个完整的 Transformer 模型,包括编码器和解码器。
    • 初始化嵌入层,用于将源语言和目标语言的词汇表映射到嵌入空间。
    • 初始化位置编码,用于将位置信息注入到嵌入表示中。
    • 定义前向传播函数 forward,包括以下步骤:
      • 调整源序列和目标序列的维度。
      • 生成源序列和目标序列的掩码,以避免在计算时考虑填充位置。
      • 对源序列和目标序列进行嵌入和位置编码。
      • 通过 Transformer 模型进行前向传播,生成隐藏表示。
      • 通过全连接层生成最终输出,并调整维度。
知识点总结
  • 位置编码 (Positional Encoding): 用于在序列模型中注入位置信息,使得模型能够区分序列中不同位置的元素。
  • Transformer 模型: 一种基于注意力机制的序列到序列模型,能够并行处理序列数据,具有较高的效率和性能。
  • 掩码 (Masking): 用于在计算注意力权重时忽略填充位置或未来位置,确保模型在处理变长序列时能够正确地聚焦在有效位置上。
  • 嵌入层 (Embedding Layer): 用于将离散的词汇表映射到连续的向量空间,以便模型能够处理输入的词汇信息。
训练
训练函数 train
def train(model, iterator, optimizer, criterion, clip):
    model.train()  # 设置模型为训练模式
    epoch_loss = 0  # 初始化epoch损失
    
    for i, batch in enumerate(iterator):
        # 遍历每个批次
        src, trg = batch  # 获取源序列和目标序列
        if src.numel() == 0 or trg.numel() == 0:
            continue  # 跳过空的批次
        
        src, trg = src.to(DEVICE), trg.to(DEVICE)  # 将数据移动到设备(CPU或GPU)
        
        optimizer.zero_grad()  # 清空优化器的梯度
        output = model(src, trg)  # 前向传播,获取模型输出
        
        output_dim = output.shape[-1]  # 获取输出的最后一维大小
        output = output[:, 1:].contiguous().view(-1, output_dim)  # 调整输出张量的形状,去掉起始标记
        trg = trg[:, 1:].contiguous().view(-1)  # 调整目标张量的形状,去掉起始标记
        
        loss = criterion(output, trg)  # 计算损失
        loss.backward()  # 反向传播,计算梯度
        
        clip_grad_norm_(model.parameters(), clip)  # 梯度裁剪,防止梯度爆炸
        optimizer.step()  # 更新模型参数
        
        epoch_loss += loss.item()  # 累加批次损失

    print(f"Average loss for this epoch: {epoch_loss / len(iterator)}")  # 打印平均损失
    return epoch_loss / len(iterator)  # 返回平均损失
详细解释:
  1. 模型设置为训练模式

    • model.train():设置模型为训练模式,启用Dropout和BatchNorm。
  2. 初始化epoch损失

    • epoch_loss = 0:初始化变量用于累加每个批次的损失。
  3. 遍历每个批次

    • for i, batch in enumerate(iterator):遍历数据加载器中的每个批次。
    • src, trg = batch:获取当前批次的源序列和目标序列。
    • if src.numel() == 0 or trg.numel() == 0:检查批次是否为空,如果为空则跳过。
    • src, trg = src.to(DEVICE), trg.to(DEVICE):将数据移动到指定设备(如GPU)。
  4. 前向传播和计算损失

    • optimizer.zero_grad():清空优化器中的梯度。
    • output = model(src, trg):将源序列和目标序列输入模型,得到预测输出。
    • output_dim = output.shape[-1]:获取输出的词汇表大小。
    • output = output[:, 1:].contiguous().view(-1, output_dim):调整输出张量的形状,去掉起始标记,并转换为二维张量。
    • trg = trg[:, 1:].contiguous().view(-1):调整目标张量的形状,去掉起始标记,并转换为一维张量。
    • loss = criterion(output, trg):计算预测输出和目标之间的损失。
    • loss.backward():反向传播,计算梯度。
  5. 梯度裁剪和参数更新

    • clip_grad_norm_(model.parameters(), clip):对梯度进行裁剪,防止梯度爆炸。
    • optimizer.step():更新模型参数。
  6. 累加损失和打印平均损失

    • epoch_loss += loss.item():将当前批次的损失累加到epoch损失中。
    • print(f"Average loss for this epoch: {epoch_loss / len(iterator)}"):打印本epoch的平均损失。
    • return epoch_loss / len(iterator):返回本epoch的平均损失。
评估函数 evaluate
def evaluate(model, iterator, criterion):
    model.eval()  # 设置模型为评估模式
    epoch_loss = 0  # 初始化epoch损失
    with torch.no_grad():  # 关闭梯度计算
        for i, batch in enumerate(iterator):
            # 遍历每个批次
            src, trg = batch  # 获取源序列和目标序列
            if src.numel() == 0 or trg.numel() == 0:
                continue  # 跳过空批次
            
            src, trg = src.to(DEVICE), trg.to(DEVICE)  # 将数据移动到设备(CPU或GPU)
            
            output = model(src, trg, 0)  # 关闭教师强制,前向传播,获取模型输出
            
            output_dim = output.shape[-1]  # 获取输出的最后一维大小
            output = output[:, 1:].contiguous().view(-1, output_dim)  # 调整输出张量的形状,去掉起始标记
            trg = trg[:, 1:].contiguous().view(-1)  # 调整目标张量的形状,去掉起始标记
            
            loss = criterion(output, trg)  # 计算损失
            epoch_loss += loss.item()  # 累加批次损失
        
    return epoch_loss / len(iterator)  # 返回平均损失
详细解释:
  1. 模型设置为评估模式

    • model.eval():设置模型为评估模式,禁用Dropout和BatchNorm。
  2. 初始化epoch损失

    • epoch_loss = 0:初始化变量用于累加每个批次的损失。
  3. 关闭梯度计算

    • with torch.no_grad():关闭梯度计算,节省内存和计算资源。
  4. 遍历每个批次

    • for i, batch in enumerate(iterator):遍历数据加载器中的每个批次。
    • src, trg = batch:获取当前批次的源序列和目标序列。
    • if src.numel() == 0 or trg.numel() == 0:检查批次是否为空,如果为空则跳过。
    • src, trg = src.to(DEVICE), trg.to(DEVICE):将数据移动到指定设备(如GPU)。
  5. 前向传播和计算损失

    • output = model(src, trg, 0):将源序列和目标序列输入模型,得到预测输出,关闭教师强制。
    • output_dim = output.shape[-1]:获取输出的词汇表大小。
    • output = output[:, 1:].contiguous().view(-1, output_dim):调整输出张量的形状,去掉起始标记,并转换为二维张量。
    • trg = trg[:, 1:].contiguous().view(-1):调整目标张量的形状,去掉起始标记,并转换为一维张量。
    • loss = criterion(output, trg):计算预测输出和目标之间的损失。
  6. 累加损失和返回平均损失

    • epoch_loss += loss.item():将当前批次的损失累加到epoch损失中。
    • return epoch_loss / len(iterator):返回本epoch的平均损失。
主函数
import torch
import torch.nn as nn
import torch.optim as optim

# 模型定义(假设有一个initialize_model函数)
INPUT_DIM = len(en_vocab)
OUTPUT_DIM = len(zh_vocab)
EMB_DIM = 128
HID_DIM = 256
N_LAYERS = 2
DROPOUT = 0.5

# 初始化模型
model = initialize_model(INPUT_DIM, OUTPUT_DIM, EMB_DIM, HID_DIM, N_LAYERS, DROPOUT, DEVICE)
print(f'The model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters')

# 定义损失函数
criterion = nn.CrossEntropyLoss(ignore_index=zh_vocab['<pad>'])
# 初始化优化器
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

#todo:可以使用学习率调整方法进行优化

# 训练模型
save_path = '../model/best-model.pt'
train_model(model, train_loader, dev_loader, optimizer, criterion, N_EPOCHS, CLIP, save_path = save_path)

print(f"训练完成!模型已保存到:{save_path}")

这段代码定义并初始化了神经网络模型,并使用Adam优化器进行训练,最后将最佳模型保存。

5. 评价与测试

# 在开发集上进行评价
# model.load_state_dict(torch.load('../model/best-model.pt'))

# 计算BLEU分数
# bleu_score = calculate_bleu(dev_loader, en_vocab, zh_vocab, model, DEVICE)
# print(f'BLEU score = {bleu_score:.2f}')

这里加载最佳模型,并计算BLEU分数进行评价。

# 对测试集进行翻译
save_dir = '../results/submit_task2.txt'
with open(save_dir, 'w') as f:
    translated_sentences = []
    for batch in test_loader:  # 遍历所有数据
        src, _ = batch
        src = src.to(DEVICE)
        translated = translate_sentence(src[0], en_vocab, zh_vocab, model, DEVICE, max_length=50)  # 翻译结果,max_length生成翻译的最大长度
        results = "".join(translated)
        f.write(results + '\n')  # 将结果写入文件
    print(f"翻译完成,结果已保存到{save_dir}")

这部分代码用于对测试集进行翻译,并将翻译结果保存到文件中。

加载术语词典

这段代码的目的是在机器翻译模型的基础上,利用术语词典对翻译结果进行后处理,以提升翻译结果的准确性和一致性。

def load_dictionary(dict_path):
    term_dict = {}
    with open(dict_path, 'r', encoding='utf-8') as f:
        data = f.read()
    data = data.strip().split('\n')
    source_term = [line.split('\t')[0] for line in data]
    target_term = [line.split('\t')[1] for line in data]
    for i in range(len(source_term)):
        term_dict[source_term[i]] = target_term[i]
    return term_dict

这个函数的目的是加载术语词典,并将其存储在一个字典中。术语词典文件dict_path的格式是每行一个术语对,源语言和目标语言术语之间用制表符('\t')分隔。

  1. term_dict:一个空字典,用于存储术语对。
  2. 打开文件并读取其内容。
  3. 将内容按行拆分成列表。
  4. 分别提取源语言术语和目标语言术语。
  5. 将术语对存储在字典term_dict中,其中源语言术语作为键,目标语言术语作为值。
  6. 返回加载好的术语词典。

术语后处理函数

def post_process_translation(translation, term_dict):
    """ 使用术语词典进行后处理 """
    
    translated_words = [term_dict.get(word, word) for word in translation]
    return "".join(translated_words)

这个函数用于根据术语词典对翻译结果进行后处理。

  1. translation:翻译结果,即一个单词列表。
  2. term_dict:术语词典。
  3. 遍历翻译结果中的每个单词,如果该单词在术语词典中存在,则用词典中的对应术语替换,否则保持不变。
  4. 返回处理后的翻译结果。

主程序

dict_path = '../dataset/en-zh.dic'  # 这应该是你的术语词典文件路径
term_dict = load_dictionary(dict_path)
save_dir = '../results/submit_add_dict.txt'
with open(save_dir, 'w') as f:
    translated_sentences = []
    for batch in test_loader:  # 遍历所有数据
        src, _ = batch
        src = src.to(DEVICE)
        translated = translate_sentence(src[0], en_vocab, zh_vocab, model, DEVICE)  # 翻译结果
        results  = post_process_translation(translated, term_dict)
        results = "".join(results)
        f.write(results + '\n')  # 将结果写入文件
        # break
    print(f"翻译完成,结果已保存到{save_dir}")

这部分代码是主程序,用于加载术语词典、翻译句子并进行术语后处理,最后将结果保存到文件中。

  1. 设置术语词典文件路径dict_path并加载词典。
  2. 设置保存翻译结果的文件路径save_dir
  3. 打开保存文件,以写模式。
  4. 遍历测试数据集test_loader中的每个批次。
  5. 从批次中提取源句子并将其放到设备上(如GPU)。
  6. 调用translate_sentence函数对源句子进行翻译。
  7. 调用post_process_translation函数对翻译结果进行术语后处理。
  8. 将处理后的翻译结果转换为字符串并写入文件。
  9. 打印完成信息。
  • 22
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值