paper:
Bridging the Gap between Training and Inference for Neural Machine Translation
神经机器翻译(NMT)根据上下文单词来预测下一个词,并按照序列生成整个目标句子。在训练时,它以正确标注的单词作为上下文进行预测,而在推理时,它只能从头开始生成整个序列。训练和推理的差异造成了整个过程的误差累积。除此之外,词级别的训练要求生成的序列和正确标注的序列严格匹配,这会导致对不同于正确标注序列但是合理翻译的句子的过度校正。这篇文章提出了一种新颖的解决方法,在训练阶段,模型不仅从正确标注的序列中采样上下文单词,同时在预测生成的序列中采样上下文单词;在推理阶段,使用句级别最优的预测。
Introduction
exposure bias 会导致在训练和推断之间产生偏差,随着目标序列的增长,错误会不断地在序列中积累。
为了解决这个问题,直观来看,模型应该在与推理相同的条件下进行训练,在训练时使用正确标注的词和预测得到的词作为上下文可以作为一种解决方法。
NMT模型通常使用优化交叉熵损失的方法来进行训练,这需要预测序列和正确标注序列在词级别的严格匹配。一旦模型生成偏离正确标注序列的单词,交叉熵损失就会立即纠正错误并且让剩下的预测回归正确标注序列。这会导致一个新的问题:一个句子通常有多个合理的翻译,即是模型产生了与正确标注序列不同的词,也不能说模型犯了错误。但是这种强行校正却可能令预测序列产生语法或者意思上的错误,这个问题被称为过度校正(Over Correction)错误。
文章提出一种减小训练和推理之间 Gap 的方法,并且提高模型对过度校正纠正的能力。文章首先从预测的单词中选择 oracle words,然后从oracle words 和正确标注的单词中进行抽样,所得的结果作为上下文进行训练。oracle words 不仅仅通过单词贪婪搜索来选择,还用句级的评估结果来选择,例如BLEU。在训练开始时,模型以更大的概率选择正确标注的单词作为上下文进行训练,随着模型收敛,模型以更大的概率选 oracle words 作为上下文进行训练。通过这种方法,模型训练从完全引导变为较少引导,并且模型有机会学习处理推理中产生的错误,修正因为过度校正产生的错误。
RNN-based NMT Model
作者提出的方法可以用于任何NMT模型,paper中用 RNN-based NMT (Bahdanau et al., 2015) 做例子说明。
假设输入序列为
Encoder
使用双向GRU作为编码器,
Attention
Attention用于提取 source information(称为 source context vector),在第
source context vector 是 所有source word 的加权和:
Decoder
使用GRU的一种变体作为decoder 来展开 target information
第
目标词表的概率分布
其中
Approach
从the ground truth words 和 the previous predicted words(oracle words)中采样上下文单词,其中有两种 oracle words 选择方法,一种方法是使用贪婪搜索算法在单词级别选择oracle words;另一个是在 sentence-level optimum选择 oracle sequence.
The sentence-level oracle provides an option of n-gram matching with the ground truth sequence and hence inherently has the ability of recovering from overcorrection for the alternative context
预测第
首先需要在第
接下来在概率为
最后使用采样所得词代替(6)和(7)中的原始训练词
Oracle Word Selection
在第
Oracle Word 应该是一个和 ground truth 相似或是同义词
Oracle Word 选择有两种:Word-Level Oracle 和 Sentence-Level Oracle
word-level greedy search 可以在每一个时间步找出 Oracle Word,可以通过beam search扩大搜索空间,然后使用 sentencelevel metric (BLEU, GLEU, ROUGE)对候选翻译重新排序,从而进一步优化 Oracle Word,选择出的翻译句子称为 oracle sentence, 翻译中的words 就是 Sentence-level Oracle
1、Word-Level Oracle
在第
使用Gumbel-Max方法可以更高效的从候选分布中进行抽样。Gumbel噪声,可以视为一种正则化,可以加在公式(8) 的
其中,
请注意,Gumbel噪声仅用于选择 oracle words ,它不会影响训练的损失函数
2、Sentence-Level Oracle
句子级别的Oracle使用N-gram匹配,可令翻译变得更加灵活,同时选用BLEU作为sentence-level metric。为选择句级别的Oracle,文章使用beam search得到每一个batch全部的候选句,假设beam size = k,由此获得 k-best candidate translations。在进行beam search的过程中,Gumbel噪声可以用于每一个词的生成过程,然后计算预测语句与ground truth的BLEU分数,拥有最高值的预测语句被视为 oracle sentence
令
但是预测语句并不一定与标注语句拥有同样的长度,所以文中使用 Force Decoding 的方法来确定两个语句拥有相同的长度。
Force Decoding
对于超过长度的预测语句,此方法会提前选择EOS结束预测,对于短于长度的预测语句,此方法会选择除EOS外最高概率的值继续进行预测。
Sampling with Decay
从 ground truth words
训练刚开始若从
Training
计算训练损失时,我们不将Gumbel噪声添加到分布中,通过公式(6,7,8,9)得到
其中