关于SeqGan的记录

最近阅读了两篇关于seq gan的论文,以下为两篇论文的记录。

1、SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient

        将gan应用于序列生成中会遇到一些问题:1、generator的作用是为了让输出连续,discriminator指导generator更新,而序列生成是离散的,discriminator的输入是采样后的token,往往不可导,discriminator无法更新,从而不能指导generator。2、discriminator往往用于评估完整的sequence,对于一步步慢慢输出的sequence,无法衡量当前与未来整个句子输出的收益。将强化学习结合于gan中,可以解决该问题。

        在gan中,生成模型可以看成强化学习中的agent,其种state是当前生成的句子,action为下一步生成的token。将discriminator作为reward反馈给generator。对每一次生成token的过程,通过蒙特卡罗搜索,可以探索获得整个句子的reward,然后反向传回当前位置。

        SeqGan通过策略梯度进行更新,生成模型的目标为下式,其中使用discriminator作为reward指导generator。

 

        在每一步中,为了关注长期收益,不仅仅要考虑到目前为止序列是否合适,还要关注未来的序列结果。通过蒙特卡罗搜索,其中下式的G与generator相同,指导探索未来的序列。

        为了减少方差和获得更准确的收益,通过该策略搜索N次,进行平均。收益如下式。

        接着使用传统的G与D的更新方式,首先训练判别模型。

        接着更新生成模型

        该论文采用了提前训练生成器,这样可以之后训练生成器的过程更加高效。为了让discriminator训练更加平衡,每次训练discriminator的过程,使用相同数量的正样本与生成的负样本。

        该模型的基本结构如下图。

        其中generator使用循环神经网络lstm,discriminator使用卷积网络,输出输入的句子是否是真实的概率。其中优化的目标是最小化实际label与生成序列概率的交叉损失熵。

        论文做了三组实验,如生成合成句子,其中对训练参数进行了一些研究,因为他会影响seqGan的稳定性,发现训练生成器一次,判别器多次效果会很好。其次还做了一些文本生成,音乐生成的实验。

2、Improving Neural Machine Translation with Conditional Sequence Generative Adversarial Nets

        该论文将seqGan应用于机器翻译中,与上篇论文的结构基本相似。该论文主要提出了将mrt与gan结合的机器翻译模型,除了使用discriminator指导generator的生成,同时添加了句子的bleu值与discriminator中。传统的机器翻译使用最大似然进行训练。[1]mrt提出了使用bleu值进行训练。由于bleu值是计算ngram的精确度,尽管可以将句子的好坏,但是高的bleu值不代表可以生成更好的句子,它不能够完全体现句子的分布。

        该论文使用gan的结构,生成器基于输入句子生成目标句子,而判别器基于输入句子,判断输出句子是不是真实的(即人翻译的)。判别器除了知道生成器生成一个理想的句子分布,同时希望有一个静态的目标来指导生成器,如bleu值。具体结构如下图所示。

        本文提出了一种基于bleu值的生成对抗网络,由三部分构成,生成器,判别器以及bleu值目标值。Generator使用与机器翻译模型相同的结构,如transformer。Discriminator使用cnn结构,来判断生成的句子是真实与否,具体做法如下:

        输入句子x,输入到卷积网络中,经过卷积与最大池化,得到如下的特征向量,生成的句子同样这样表示,得到对应的特征向量。最终输出的句子是否真实通过输入到sigmoid函数中。

 

       Bleu目标值主要用过计算生成的句子与真实的句子的bleu值。整个生成器的目标为最大化期望收益:

        其中R为判别器的输出,由两部分组成,一部分为判别器的概率输出,一部分为bleu值。

        其中lamda为超参数。如果Y不是完整的输出句子,discriminator的值将意义不大在gan的结构中。因此,生成每一步的结果时,采用蒙特卡罗搜索搜索完整句子,获得收益,并反向传回来。最终reward如下,与上篇论文基本相似。

        同样采用最大似然进行预训练得到生成器。其次,预训练判别器,直到判别器准确率达到一定的阈值。在更新梯度的过程中,为了保证连续,对判别器的weight进行了clip。

        最终实验结果。

        该论文也将算法与mrt进行了对比,mrt的损失函数目标为

        其中后一项为bleu值,前面一项为生成该句子的概率。由于搜索句子的空间是指数级的,因此一般部分采样。与本文的gan对比,当参数lambda等于0时,几乎与mrt等同。论文还加入了一个动态的判别器,实验结果如下

 

        在预训练时,需要平衡生成器与判别器。如果生成器太强大,判别器太弱,回导致判别器无法知道生成器。反之,判别器总是惩罚坏的预测,生成器总是不被鼓励,生成器的表现会越来越差。同时,在蒙特卡罗搜索过程中,搜索的次数N不能太小,因为可能给模型指导错误的方向,同时太大的N效果不会变化很明显。

[1] Minimum Risk Training for Neural Machine Translation

[2] SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient

[3] Improving Neural Machine Translation with Conditional Sequence Generative Adversarial Nets

  • 6
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值