论文分享-- >Adversarial Learning for Neural Dialogue Generation

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Mr_tyting/article/details/80346559

本次要分享的论文是Adversarial Learning for Neural Dialogue Generation,论文链接dialogue-gan,该论文所讲内容和上一篇分享的论文非常类似,都是用GANRL 来做生成的,只不过本篇论文讲的是对话生成,有一些不一样的地方,因此再次分析总结下,就不细分析代码了。

动机

首先还是来讲讲论文所说的动机,传统的seq2seq 方法是以MLE 作为目标函数,虽然在某些任务上取得了不错的成绩,但是也有一些显而易见的缺点:往往生成的句子是乏味的、通用的(低质量的)、短视的、重复性的。

一个好的模型,其生成的句子应该与人类生成句子真假难辨。因此论文采取了GAN 的思想方法,但是传统的Gan 又无法适用于离散的数据上,因此再采用RL的方法,和上一篇分享的论文很像,判别器返回一个reward 给生成器,指导其生成什么样的句子。

模型

生成器:是一个带有attention 机制的 seq2seq 模型。
判别器:是一个Hierachical Neural Network,其参考的论文链接hierarchical,以QA任务为例,一对query answer 样本,如果该answer 是人类生成的,则该样本标签为1,如果为生成器生成的则标签为0。将queryanswer 分别经过相互独立的RNN ,得到两者的最后的state ,然后做concat 操作,作为context_input,将该context_input 喂给一个RNN 做一个二分类的训练。嗯,就是这么简单。

Policy-Gradient-Training
首先需要知道,判别器中返回的reward 具体是什么信息?论文中是指x,y (其中y 时生成器生成的)在判别器中被识别为真的概率值,也即是打分值作为reward 回传给生成器。

这里需要了解在强化学习的几个重要概念中:state,action,policy,rewardstate 为现在已经生成的tokens , action 是下一个即将生成的token , policyGAN 的生成器,reward 为GAN 的判别器所回传的信息。

由强化学习的知识可知,生成器的目标就是使得maximize expected end reward,论文中的公式:

J(θ)=Eyp(y|x)(Q+(x,y)|θ)

这里要特别特别注意:上式中的y 并不是true_data,而是生成器生成的!!!也就是论文中所说的通过policy smaple 出来的。 只在在pretrain 时,才用到true_data

这里写图片描述

这个b 可视为一个baseline,可以这样理解,如果某一个actionreward 很大,则下次生成该sequence 的几率就会增大,但是如果reward 都是正值呢?例如上面判别器给起打分,其分值都为正值,那么每个sequencereward 都是正的,我们希望有个区分,对于reward 较低的sequence 相对来说要抑制他的生成,故减去一个baseline ,使得reward 有正有负。

同样的,上面只是对一个整句进行打分的,由上一篇论文的分析可知,对每步生成的token 进行打分十分有必要的。
论文中对于reward for everystep 有两种解决办法:

  • 蒙特卡洛树搜索法
  • train 一个判别器,使其能对整句打分,也可以对部分句子打分。

论文中也倾向于使用蒙特卡洛搜索树方法,虽然比较耗时。

这里写图片描述

Teacher Forcing

如果我们随机初始化生成器的话,可能存在以下问题:

  • This reward is used to promote or discourage the generator’s own generated sequences.
  • Usually It knows that the generated results are bad, but does not know what results are good.

    一句话,就是生成器可能知道哪些response 是好的、坏的。但是并不知道怎么去生成好的、符合要求的句子,当遇到某些train batch 时,生成器生成的respones 判别器很容易的就能判断出来,这就导致了生成器的loss 突然变得很大。训练不够稳定,容易训飞了。

为了缓解上面所说的问题,论文中提出了Teacher Forcing ,就是给出一些read_data 来指导生成器,告诉生成器哪些response 是好的,让他学着去生成。其实就是用seq2seq 里面的MLE loss 去纠正生成器。嗯,就是这么简单。相当于上一篇论文中pretrain 生成器。

整体流程

这里写图片描述

个人总结

  • 整体难度不大,和上一篇有很多相似之处。要特别注意的时,在train 生成器时不用true_y,只是在teacher Forcing 和判别器处用到true_y
  • 这种方法训练模型相对于传统的seq2seq 可能比较耗时。
阅读更多
想对作者说点什么?

博主推荐

换一批

没有更多推荐了,返回首页