从人工反馈中学写摘要
前言
OpenAI前一段又搞了个大新闻: 他们利用人工反馈的干预,产生出了质量大大超过人力生成的摘要论文。 到底好成什么样呢, 有图为证:
OpenAI请了一些人来做labeler,每次给他们一对摘要,让他们判断哪个比较好一些。可以想见,如果每次都给一对同样是人力生成的摘要,那么大概50%的时间会是第一篇比第二篇更好。这就是上图黑色虚线的位置。纵坐标"Fraction preferred to ref",指的是有多大比例的第一篇比第二篇文章要好。
如果每次第一篇摘要都是模型生成的,那么我们可以比较出模型和人力生成摘要的差距(或优越性)。 最上面黄色的线说明模型摘要的质量远超人工摘要。且模型越大,质量越好。
Seq2Seq
序列-序列模型是NLP模型里的经典架构。它可以用来解决基本上所有语言生成类的问题,摘要生成任务只是其中之一。我们先看一下Seq2Seq的历史和现状。
序列-序列模型
先看图。
这是个经典的序列-序列架构,用来解决Machine translation即机器翻译问题。这里的编码和解码器都画成了LSTM的样子,其实可以替换成任何的模型,比如transformer。 在训练的时候,注意每一步的预测都是在完全知道以前的ground truth的情况下发生的。举例来说,我们在预测"are"的时候,用了"How"最为解码器的输入,在预测"you"的时候,用了"How are"作为输入。这个训练的方式叫做"teacher forcing"。就好像有位老师,不厌其烦的告诉学生习题每一步的正确答案是什么。
在inference的时候,很显然我们事先并不知道答案,所以解码器每一步的预测都要根据它自己上一步的预测做出。比如,在预测"you"的时候,它利用"are"作为输入,但是这个"are"来自于它自己在前一步的预测,如果这个预测错了,比如预测成了"is",那么必然会对当前步正确的预测"you"产生影响。在这里我们可以看到training和inference阶段的mismatch:一个是利用ground truth来做预测,另一个则是完全靠自己的预测往下进行。
传统的Seq2Seq用的是Cross Entropy作为Loss
痛点和解决方案
Exposure bias
刚才我们看到Seq2Seq模型在训练和推断生成时,方式有很明显的不同。这会带来在这两阶段非常大的数据分布的差异。简单来说,模型是在总是知道ground truth的环境下训练出来的(也被形象的称为teacher forcing),但是在实战时,完全不知道ground truth。一直以来有Seq2Seq领域有很多的研究在处理这个注明的Exposure bias问题。
Metrics
一般文本生成任务常用BLEU或者ROUGE做metrics。大致上就是看看在生成句子和参考句子之间有多少n-gram的overlap。传统的Seq2Seq在metrics方面做得很差。首先,Cross Entropy作为损失函数,根本没有直接对metrics进行优化。这个问题常用的解决思路是利用强化学习,特别是policy gradient,它可以直接针对reward进行优化。在这里直接把reward定义为metrics就可以。另外,BLEU或者ROUGE属于比较粗糙的metrics。 文本生成的结果可以千变万化,其实区区的n-gram重合能够描述的。我印象中曾有用生成句子和参考句子的embedding的cosine similairity来做metrics的研究,这就比BLEU这种土办法强了很多倍。
当然了,最强的metrics其实只能由人类来产生。比如一个句子翻译的好不好,有相关经验的人一看便知。有一类做法是利用GAN里的discriminator来产生metircs:它产生的概率D(x)表示生成的句子和真实的结果有多相似。这里的相似可不是有多少的n-gram重合,而是数据分布上的相似。OpenAI的做法和GAN不一样,作者直接利用人的反馈作为label另外作了一个reward模型,这样,弱metrics的问题就解决了。
解决方案
其实强化学习可以同时兼顾以上的两个痛点。针对exposure bias,强化学习在训练的时候就可以避免填鸭式(teacher forcing)。
以策略学习为例,
目标是最大化预期回报,其梯度可以表达为右式。针对Seq2Seq的任务,右边的梯度可以展开为:
这里的r可以被灵活的定义为任何我们愿意优化的目标(BLEU, ROUGE, semantic similarity, etc), 而Action a t a_t at是当前策略下在时间t预测的token,在下一个时间步t+1,它会成为 s t + 1 s_{t+1} st+1的一部分,用来预测下一个token。这个过程彻底避免利用ground truth,和模型在evaluation以及inference的时候完全一致。
图灵测试和GAN
关于传统Seq2seq架构的不足之处,还不止这些。进一步追本溯源的话,其实大多数的文本生成模型,最终是以图灵测试的方式接受人类考验。
就是说,生成的文本由人类来判断是否真实/正确。用数学式子表示,就是
数据抽样于模型 q θ ( x ) q_\theta(x) qθ