oracle 结果集已耗尽解决方法_ACL最佳长论文,解决NMT中过度校正问题

9669b1f45239f312606974576a0ebf94.png

paper:

Bridging the Gap between Training and Inference for Neural Machine Translation

神经机器翻译(NMT)根据上下文单词来预测下一个词,并按照序列生成整个目标句子。在训练时,它以正确标注的单词作为上下文进行预测,而在推理时,它只能从头开始生成整个序列。训练和推理的差异造成了整个过程的误差累积。除此之外,词级别的训练要求生成的序列和正确标注的序列严格匹配,这会导致对不同于正确标注序列但是合理翻译的句子的过度校正。这篇文章提出了一种新颖的解决方法,在训练阶段,模型不仅从正确标注的序列中采样上下文单词,同时在预测生成的序列中采样上下文单词;在推理阶段,使用句级别最优的预测

Introduction

exposure bias 会导致在训练和推断之间产生偏差,随着目标序列的增长,错误会不断地在序列中积累。

为了解决这个问题,直观来看,模型应该在与推理相同的条件下进行训练,在训练时使用正确标注的词和预测得到的词作为上下文可以作为一种解决方法。

NMT模型通常使用优化交叉熵损失的方法来进行训练,这需要预测序列和正确标注序列在词级别的严格匹配。一旦模型生成偏离正确标注序列的单词,交叉熵损失就会立即纠正错误并且让剩下的预测回归正确标注序列。这会导致一个新的问题:一个句子通常有多个合理的翻译,即是模型产生了与正确标注序列不同的词,也不能说模型犯了错误。但是这种强行校正却可能令预测序列产生语法或者意思上的错误,这个问题被称为过度校正(Over Correction)错误。

文章提出一种减小训练和推理之间 Gap 的方法,并且提高模型对过度校正纠正的能力。文章首先从预测的单词中选择 oracle words,然后从oracle words 和正确标注的单词中进行抽样,所得的结果作为上下文进行训练。oracle words 不仅仅通过单词贪婪搜索来选择,还用句级的评估结果来选择,例如BLEU。在训练开始时,模型以更大的概率选择正确标注的单词作为上下文进行训练,随着模型收敛,模型以更大的概率选 oracle words 作为上下文进行训练。通过这种方法,模型训练从完全引导变为较少引导,并且模型有机会学习处理推理中产生的错误,修正因为过度校正产生的错误。

RNN-based NMT Model

作者提出的方法可以用于任何NMT模型,paper中用 RNN-based NMT (Bahdanau et al., 2015) 做例子说明。

假设输入序列为

,翻译结果为

Encoder

使用双向GRU作为编码器,

代表
代表
的emmbedding

8a2590fd59c64c0e3bf9f0996aaf7d3b.png

Attention

Attention用于提取 source information(称为 source context vector),在第

步,target word
和第
个source word 的相关度为

aad48e5394940d7e163c7f55ffe23cb4.png

source context vector 是 所有source word 的加权和:

2733f78c16b2bb68b7735624e7e5ae2c.png

Decoder

使用GRU的一种变体作为decoder 来展开 target information

步,the target hidden stat
为:

6589c3b12bcb5379c150f37dfc6d40b4.png

目标词表的概率分布

为:

ed0f2024796ca201973bfed17ceb7e07.png

其中

是一个线性变换,
用来映射
, 每一个目标词在
中拥有一个对应维度。

Approach

7b345b41088c178a83fa05af1c0243d3.png

从the ground truth words 和 the previous predicted words(oracle words)中采样上下文单词,其中有两种 oracle words 选择方法,一种方法是使用贪婪搜索算法在单词级别选择oracle words;另一个是在 sentence-level optimum选择 oracle sequence.

The sentence-level oracle provides an option of n-gram matching with the ground truth sequence and hence inherently has the ability of recovering from overcorrection for the alternative context

预测第

个 target word
步骤:

首先需要在第

步选择Oracle词汇

接下来在概率为

的标注词汇
和概率为
的Oracle词汇中进行采样。

最后使用采样所得词代替(6)和(7)中的原始训练词

进行训练和推理。

Oracle Word Selection

在第

步,需要
去预测
,大多数方法都是选择 the groud truth,在这里从
和 the ground truth 中采样来作为context words

Oracle Word 应该是一个和 ground truth 相似或是同义词

Oracle Word 选择有两种:Word-Level Oracle 和 Sentence-Level Oracle

word-level greedy search 可以在每一个时间步找出 Oracle Word,可以通过beam search扩大搜索空间,然后使用 sentencelevel metric (BLEU, GLEU, ROUGE)对候选翻译重新排序,从而进一步优化 Oracle Word,选择出的翻译句子称为 oracle sentence, 翻译中的words 就是 Sentence-level Oracle

1、Word-Level Oracle

b1f35d513b1a92a1c802dc7372286ddb.png

在第

个时间步的decoder,选择 word-level oracle 最直接的方式就是选择公式(9)中词汇分布
中概率最高的词。但是这种方法鲁棒性不足。

0762ae3040ea6185465ba1ddefb8b1e6.png

使用Gumbel-Max方法可以更高效的从候选分布中进行抽样。Gumbel噪声,可以视为一种正则化,可以加在公式(8) 的

上,然后做softmax,最后目标词
的词汇分布近似为:

d306ef4116682f1b844a2f79b0152b32.png

其中,

为一元随机向量
计算得到的Gumbel噪声,
是温度。当
接近0时,
方程与
类似,当
接近无穷大,
方程接近一元分布。最后1-best word 由下面公式选择:

555faf7cf5d3dbe344a45be325426e4b.png

请注意,Gumbel噪声仅用于选择 oracle words ,它不会影响训练的损失函数

2、Sentence-Level Oracle

句子级别的Oracle使用N-gram匹配,可令翻译变得更加灵活,同时选用BLEU作为sentence-level metric。为选择句级别的Oracle,文章使用beam search得到每一个batch全部的候选句,假设beam size = k,由此获得 k-best candidate translations。在进行beam search的过程中,Gumbel噪声可以用于每一个词的生成过程,然后计算预测语句与ground truth的BLEU分数,拥有最高值的预测语句被视为 oracle sentence

,在第
步解码,句级别的Oracle词为:

e330c287afb8a87cb0d1dba8a8238bb4.png

但是预测语句并不一定与标注语句拥有同样的长度,所以文中使用 Force Decoding 的方法来确定两个语句拥有相同的长度。

Force Decoding

对于超过长度的预测语句,此方法会提前选择EOS结束预测,对于短于长度的预测语句,此方法会选择除EOS外最高概率的值继续进行预测。

Sampling with Decay

从 ground truth words

和 oracle word
中采样作为

训练刚开始若从

中频繁采样会导致模型收敛慢,甚至是陷入局部最优;训练后期,若任然从
中大概率采样会导致该模型没有完全学到在推论时必须面对的情况,因此,ground truth的采样概率
是不固定的,
随着训练是递减的,刚开始
,这意味着该模型完全基于 ground truth 进行训练,随着模型逐渐收敛,模型会更频繁地从oracle word 中进行选择。
的衰减函数:

b0488c5cd2911ff717193c14cbf688cc.png

是迭代次数(从0开始),
是超参数,该函数是严格单调递减的,随着训练的进行,提供地面真话的概率
逐渐减小

Training

计算训练损失时,我们不将Gumbel噪声添加到分布中,通过公式(6,7,8,9)得到

的word distribution ,损失函数是基于最大似然估计(MLE)的最大化ground truth序列的概率:

1ad306c33d9ccca4835a5643388443d3.png

其中

为训练数据集中的 sentence pairs number,
为第
个标注句子的长度,
代表第
句第
步的预测概率分布,因此
为预测第
步标注词汇
的概率
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值