序列对抗网络SeqGAN

SeqGAN源自2016年的论文《SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient》,论文地址:https://arxiv.org/pdf/1609.05473.pdf。其核心是用生成对抗网络处理离散的序列数据。

之前介绍了使用GAN生成图像的方法,由于图像是连续数据,可以使用调整梯度的方法逐步生成图像,而离散数据很难使用梯度更新。在自然语言处理(NLP)中使用GAN生成文字时,由于词索引与词向量转换过程中数据不连续,微调参数可能不起作用;且普通GAN的判别模型只对生成数据整体打分,而文字一般都是逐词(token)生成,因此无法控制细节。SeqGAN借鉴了强化学习(RL)的策略,解决了GAN应用于离散数据的问题。

概念

与基本的GAN算法一样,SeqGAN的基本原理也是迭代训练生成模型G和判别模型D。假设用G生成一个词序列组成句子,由D来判别这个句子是训练集中的真实句子(True data),还是模型生成的句子(Generate);最终目标是用模型G生成以假乱真的句子,让D无法分辨。其操作过程如下:

图片摘自论文

与普通对抗不同的是,在单次操作中,模型多次调用生成模型G和判别模型D。以生成文字为例,右侧的每一个红圈是一个生成词的操作,State为已生成的词串,在生成下一个词Next action时,先调用生成模型G生成多个备选项,然后使用判别模型对各个选项评分(reward),根据评分选择最好的策略Policy,并调整策略模型(Policy Gradient)。

强化学习

SeqGAN主要借鉴了强化学习中的方法,如果不了解强化学习很难看懂论文中的公式和推导,下面先对强化学习做一个简单的介绍。

强化学习的核心是在实践中通过不断试错来学习最好的策略,一般强化学习学到的是一系列决策,其目标是最大化长期收益,例如围棋比赛中当前的操作不仅需要考虑接下来一步的收益,还需要考虑未来多步的收益。

强化学习有几个核心概念:状态s(State)、动作a(Action)、奖励r(Reward)。以生成词系列为例,假设词系列是Y1:T=(y1,...,yt,...,yT),在第t个时间步,状态s是先前已生成的词(y1,...,yt−1),动作a是如何选择要生成的词yt,这也是生成模型的工作Gθ(yt|Y1:t−1),它通过前t-1个词以及模型参数θ来选择下一个词,确定了该词之后,状态也随之改变成s’,对应词(y1,...,yt),以此类推,最终生成的系列(y1,...,yt,...,yT),对序列的评分就是奖励r,如果生成的系列成功地骗过了判别模型D,则得1分,如果被识别出是机器生成的则得0分。

强化学习中还有两个重要概念:动作价值action-value和状态价值state-value。简单地说,动作价值就是在某个状态选择某一动作是好是坏,如果能确定每一个动作对应的价值,就很容易做出决定。动作价值不仅与当前动作有关,还涉及此动作之后一系列动作带来的价值。状态价值也是同理,它表示某个状态的好坏。

SeqGAN原理

SeqGAN中生成模型G的目标是最大化期望奖励reward,简单说就是做出可能是奖励最大的选择,其公式如下:

上式中J是目标函数,E[]是期望,R是序列整体的奖励值,s是状态,θ是生成模型的参数,y是生成的下一个词(动作action),G是生成模型,D是判别模型,Q是动作价值(action-value)。简单地解释公式:希望得到一组生成模型G参数θ;能在s0处做出最佳选择,获取最大回报RT,而如何选择动作又取决于动作的价值Q。

动作价值算法如下:

动作价值是由判别函数D判定的,第T个时间步是最后一个时间步,上式中列出的是判别函数对完整系列的打分。若判别该序列为真实文本,则奖励值R最大。

在生成第t个词时,如何选择(动作a)涉及前期已生成的t-1个词(状态s),以及后续可能的情况,假设此时用模型Gβ生成N个备选词串(Yt:T),再用判别模型D分别对生成的N句(Y1:T)打分,此时使用了蒙特卡洛方法(MC),如下式所示:

这里的生成模型Gβ与前面Gθ通常使用同样的模型参数,有时为了优化速度也可使用不同模型参数。这里使用的蒙特卡洛算法,像下棋一样,不仅要考虑当前一步的最优解,还需要考虑接下来多步组合后的最优解,用于探索此节点以及此节点后续节点(Yt:T)的可能性,也叫roll-out展开,是蒙特卡洛搜索树中的核心技巧。

根据不同的时间步,采取不同的动作价值计算方法:

在最后一个时间步t=T时,直接使用判别函数D计算价值;在其它时间步,使用生成模型Gβ和蒙特卡洛算法生成N个后续备选项,用判别函数D打分并计算分数的均值。

SeqGAN与GAN模型相同,在训练生成器G的同时,判别器D也迭代地更新其参数。

此处公式与GAN相同,即优化判别模型D的参数φ,使其对真实数据Pdata尽量预测为真,对模型Gθ生成的数据尽量预测为假。

主要流程

其主要流程如下:

图片摘自论文

  • 程序定义了基本生成器Gθ,roll-out生成器Gβ,判别器D,以及训练集S。

  • 用MLE(最大似然估计)预训练生成器G。(2行)

  • 用生成器生成的数据和训练集数据预训练判别器D。(4-5行)

  • 进入迭代对抗训练:(6行)

  • 训练生成器(7-13行)
    在每一个时间步计算Q,这是最关键的一步,它利用判别器D、roll-out生成器Gβ以及蒙特卡罗树搜索计算行为价值,然后更新policy gradient策略梯度。

  • 训练判别器(14-17行)
    将训练数据作为正例,生成器生成的样例作为反例训练判别模型D。

代码

推荐以下代码:

TensorFlow代码(官方):https://github.com/LantaoYu/SeqGAN
Pytorch代码:https://github.com/suragnair/seqGAN

其中Pytorch代码比较简单,与论文中描述的模型不完全一致,比如它的G和D都使用GRU作为基础模型,也没有实现rollout逻辑,只是一个简化的版本,优点在于代码简单,适合入门。

  • 12
    点赞
  • 60
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值