【Datawhale AI夏令营】基于术语词典干预的机器翻译挑战赛 - TASK1

一、项目概述与个人体验
今天,我根据项目文档进行了代码的运行,对项目的运行逻辑以及关键参数的作用有了初步的了解。通过这次实践,我对序列到序列(Seq2Seq)模型和门控循环单元(GRU)模型有了基本的认识,并且掌握了数据处理的基本思路。为了优化结果,我调整了N(选择数据集的前N个样本进行训练)和N_EPOCHS(一次epoch是指将所有数据训练一遍的次数)两个参数。目前,项目运行顺利,没有遇到太大的问题,整体体验良好,为后续深入研究奠定了基础。


二、Baseline概念解析
在数据科学和机器学习领域,Baseline指的是一个简单的模型或解决方案,用于作为比较的标准。对于初次参加Datawhale夏令营的小伙伴来说,Baseline通常是完成比赛或项目的第一个代码实现,其算法相对简单,更侧重于基础功能的实现。本次项目的Baseline是构建和训练一个基于PyTorch的序列到序列(Seq2Seq)机器翻译模型。


三、数据处理
1.1 TranslationDataset类
TranslationDataset类是数据处理的核心,它负责读取数据、制作词典、处理特殊词以及数字化准备。以下是该类的主要功能:

  1. 读取数据:从文件中读取英语和中文句子对。
  2. 制作词典:收集所有英语词和中文字,并给它们编号。
  3. 特殊词处理:确保专业术语被包含在词典中。
  4. 数字化准备:创建从单词到数字的映射。
class TranslationDataset(Dataset):
    def __init__(self, filename, terminology):
        self.data = []
        with open(filename, 'r', encoding='utf-8') as f:
            for line in f:
                en, zh = line.strip().split('\t')
                self.data.append((en, zh))
        
        self.terminology = terminology
        
        # 创建词汇表
        self.en_tokenizer = get_tokenizer('basic_english')
        self.zh_tokenizer = list  # 使用字符级分词
        
        en_vocab = Counter(self.terminology.keys())
        zh_vocab = Counter()
        
        for en, zh in self.data:
            en_vocab.update(self.en_tokenizer(en))
            zh_vocab.update(self.zh_tokenizer(zh))
        
        # 添加特殊标记和常用词到词汇表
        self.en_vocab = ['<pad>', '<sos>', '<eos>'] + list(self.terminology.keys()) + [word for word, _ in en_vocab.most_common(10000)]
        self.zh_vocab = ['<pad>', '<sos>', '<eos>'] + [word for word, _ in zh_vocab.most_common(10000)]
        
        self.en_word2idx = {word: idx for idx, word in enumerate(self.en_vocab)}
        self.zh_word2idx = {word: idx for idx, word in enumerate(self.zh_vocab)}

1.2 collate_fn函数
collate_fn函数的作用是将不同长度的句子整理成一批,以便于模型处理。它主要包括以下步骤:
1.收集一批数据中的英语和中文句子。
2. 将它们填充到相同的长度,便于计算机处理。


四、模型架构
2.1 编码器(Encoder
编码器负责理解输入的英语句子。其主要组件包括:

  1. Embedding:将英语单词转换为数字向量。
  2. RNN(GRU):理解整个句子的含义。
  3. Dropout:防止模型过拟合。
    2.2 解码器(Decoder)
    解码器负责生成中文翻译。其主要组件包括:
    1.Embedding:处理中文字。
  4. RNN(GRU):记住之前翻译的内容。
  5. fc_out:预测下一个中文字。
    实例代码如下:
class Decoder(nn.Module):
    def __init__(self, output_dim, emb_dim, hid_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.rnn = nn.GRU(emb_dim, hid_dim, n_layers, dropout=dropout, batch_first=True)
        self.fc_out = nn.Linear(hid_dim, output_dim)
        self.dropout = nn.Dropout(dropout)def forward(self, input, hidden):
        input = input.unsqueeze(1)
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.rnn(embedded, hidden)
        prediction = self.fc_out(output.squeeze(1))
        return prediction, hidden

2.3 Seq2Seq模型
Seq2Seq模型将编码器和解码器组合在一起,实现完整的翻译功能。该模型使用"教师强制"策略(由teacher_forcing_ratio控制)来指导翻译过程。


五、BLEU评分函数
BLEU是一种广泛使用的机器翻译评估方法,通过比较机器翻译结果与人工翻译的参考文本来评估翻译质量。BLEU的评估过程包括:

  • 准备工作:加载源语言句子和参考翻译句子。
  • 翻译过程:使用模型翻译源语言句子。
  • 计算BLEU分数:比较模型的翻译和人工翻译,给出0到100之间的分数。
    实例代码如下:
from sacrebleu.metrics import BLEU
​
def evaluate_bleu(model, dataset, src_file, ref_file, terminology, device):
    model.eval()
    src_sentences = load_sentences(src_file)
    ref_sentences = load_sentences(ref_file)
    
    translated_sentences = []
    for src in src_sentences:
        translated = translate_sentence(src, model, dataset, terminology, device)
        translated_sentences.append(translated)
    
    bleu = BLEU()
    score = bleu.corpus_score(translated_sentences, [ref_sentences])
    
    return score

BLEU通过精确度检查、完整性检查、长度惩罚和N-gram匹配等方式来评估翻译质量。通过BLEU评分,我们可以客观地评估翻译模型性能,并比较不同模型或跟踪模型训练过程中的进步。

  • 25
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值