最近阅读了两篇关于seq gan的论文,以下为两篇论文的记录。
1、SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient
将gan应用于序列生成中会遇到一些问题:1、generator的作用是为了让输出连续,discriminator指导generator更新,而序列生成是离散的,discriminator的输入是采样后的token,往往不可导,discriminator无法更新,从而不能指导generator。2、discriminator往往用于评估完整的sequence,对于一步步慢慢输出的sequence,无法衡量当前与未来整个句子输出的收益。将强化学习结合于gan中,可以解决该问题。
在gan中,生成模型可以看成强化学习中的agent,其种state是当前生成的句子,action为下一步生成的token。将discriminator作为reward反馈给generator。对每一次生成token的过程,通过蒙特卡罗搜索,可以探索获得整个句子的reward,然后反向传回当前位置。
SeqGan通过策略梯度进行更新,生成模型的目标为下式,其中使用discriminator作为reward指导generator。
在每一步中,为了关注长期收益,不仅仅要考虑到目前为止序列是否合适,还要关注未来的序列结果。通过蒙特卡罗搜索,其中下式的G与generator相同,指导探索未来的序列。
为了减少方差和获得更准确的收益,通过该策略搜索N次,进行平均。收益如下式。
接着使用传统的G与D的更新方式,首先训练判别模型。
接着更新生成模型
该论文采用了提前训练生成器,这样可以之后训练生成器的过程更加高效。为了让discriminator训练更加平衡,每次训练discriminator的过程,使用相同数量的正样本与生成的负样本。
该模型的基本结构如下图。
其中generator使用循环神经网络lstm,discriminator使用卷积网络,输出输入的句子是否是真实的概率。其中优化的目标是最小化实际label与生成序列概率的交叉损失熵。
论文做了三组实验,如生成合成句子,其中对训练参数进行了一些研究,因为他会影响seqGan的稳定性,发现训练生成器一次,判别器多次效果会很好。其次还做了一些文本生成,音乐生成的实验。
2、Improving Neural Machine Translation with Conditional Sequence Generative Adversarial Nets
该论文将seqGan应用于机器翻译中,与上篇论文的结构基本相似。该论文主要提出了将mrt与gan结合的机器翻译模型,除了使用discriminator指导generator的生成,同时添加了句子的bleu值与discriminator中。传统的机器翻译使用最大似然进行训练。[1]mrt提出了使用bleu值进行训练。由于bleu值是计算ngram的精确度,尽管可以将句子的好坏,但是高的bleu值不代表可以生成更好的句子,它不能够完全体现句子的分布。
该论文使用gan的结构,生成器基于输入句子生成目标句子,而判别器基于输入句子,判断输出句子是不是真实的(即人翻译的)。判别器除了知道生成器生成一个理想的句子分布,同时希望有一个静态的目标来指导生成器,如bleu值。具体结构如下图所示。
本文提出了一种基于bleu值的生成对抗网络,由三部分构成,生成器,判别器以及bleu值目标值。Generator使用与机器翻译模型相同的结构,如transformer。Discriminator使用cnn结构,来判断生成的句子是真实与否,具体做法如下:
输入句子x,输入到卷积网络中,经过卷积与最大池化,得到如下的特征向量,生成的句子同样这样表示,得到对应的特征向量。最终输出的句子是否真实通过输入到sigmoid函数中。
Bleu目标值主要用过计算生成的句子与真实的句子的bleu值。整个生成器的目标为最大化期望收益:
其中R为判别器的输出,由两部分组成,一部分为判别器的概率输出,一部分为bleu值。
其中lamda为超参数。如果Y不是完整的输出句子,discriminator的值将意义不大在gan的结构中。因此,生成每一步的结果时,采用蒙特卡罗搜索搜索完整句子,获得收益,并反向传回来。最终reward如下,与上篇论文基本相似。
同样采用最大似然进行预训练得到生成器。其次,预训练判别器,直到判别器准确率达到一定的阈值。在更新梯度的过程中,为了保证连续,对判别器的weight进行了clip。
最终实验结果。
该论文也将算法与mrt进行了对比,mrt的损失函数目标为
其中后一项为bleu值,前面一项为生成该句子的概率。由于搜索句子的空间是指数级的,因此一般部分采样。与本文的gan对比,当参数lambda等于0时,几乎与mrt等同。论文还加入了一个动态的判别器,实验结果如下
在预训练时,需要平衡生成器与判别器。如果生成器太强大,判别器太弱,回导致判别器无法知道生成器。反之,判别器总是惩罚坏的预测,生成器总是不被鼓励,生成器的表现会越来越差。同时,在蒙特卡罗搜索过程中,搜索的次数N不能太小,因为可能给模型指导错误的方向,同时太大的N效果不会变化很明显。
[1] Minimum Risk Training for Neural Machine Translation
[2] SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient
[3] Improving Neural Machine Translation with Conditional Sequence Generative Adversarial Nets