SeqGAN解读

SeqGAN的概念来自AAAI 2017的SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient一文。

Motivation

如题所示,这篇文章的核心思想是将GAN与强化学习的Policy Gradient算法结合到一起——这也正是D2IA-GAN在处理Generator的优化时使用的技巧。
而该论文的出发点也是意识到了标准的GAN在处理像序列这种离散数据时会遇到的困难,主要体现在两个方面:Generator难以传递梯度更新,Discriminator难以评估非完整序列。
对于前者,作者给出的解决方案对我来说比较熟悉,即把整个GAN看作一个强化学习系统,用Policy Gradient算法更新Generator的参数;对于后者,作者则借鉴了蒙特卡洛树搜索(Monte Carlo tree search,MCTS)的思想,对任意时刻的非完整序列都可以进行评估。

问题定义

根据强化学习的设定,在时刻t,当前的状态s被定义为“已生成的序列”


,记作


,而动作a是接下来要选出的元素


,所以policy模型就是


值得一提的是,这里的policy模型是stochastic,输出的是动作的概率分布;而状态的转移则显然是deterministic,一旦动作确定了,接下来的状态也就确定了。

 

根据Policy Gradient算法,Generator的优化目标是令从初始状态开始的value(累积的reward期望值)最大化:


其中,


是完整序列的reward,


action-value函数,是指“在状态s下选择动作a,此后一直遵循着policy做决策,最终得到的value”。所以对于最右边的式子我们可以这样来理解:在初始状态下,对于policy可能选出的每个y,都计算对应的value,把这些value根据policy的概率分布加权求和,就得到了初始状态的value。

 

action-value函数

 

 

接下来的关键是如何定义

 

因为Discriminator充当了这个强化学习系统的environment,所以Discriminator的输出应当作为reward。但是Discriminator只能对生成的完整序列进行评估,因此目前只能对完整序列状态的value进行定义:

 


这是远远不够的,必须要对任意状态的value都有定义。

蒙特卡洛树搜索(MCTS)

在评估任意时刻的序列时,我们考虑的其实都是它能带来的long-term reward,就像下围棋或象棋一样,每下一步棋都要以全局为考量。在围棋和象棋的求解算法中,MCTS是一个很重要的组成部分,所以作者想到了将它应用到当前的问题。
从名字得知,这种算法属于一种蒙特卡洛方法(Monte Carlo method)——根据维基百科,也称统计模拟方法,是指使用随机数(或更常见的伪随机数)来解决很多计算问题的方法。MCTS正是这样一种基于统计模拟的启发式搜索算法,常用于游戏的决策过程。
MCTS可以无限循环,而每一次循环都由以下4个步骤构成:

  • Selection:从根节点开始,连续选择子节点向下搜索,直至抵达一个叶节点。子节点的选择方法一般采用UCT(Upper Confidence Bound applied to trees)算法,根据节点的“胜利次数”和“游戏次数”来计算被选中的概率,保持了Exploitation和Exploration的平衡,是保证搜索向最优发展的关键。
  • Expansion:在叶节点创建多个子节点。
  • Simulation:在创建的子节点中根据roll-out policy选择一个节点进行模拟,又称为playout或者rollout。它和Selection的区别在于:Selection指的是对于搜索树中已有节点的选择,从根节点开始,有历史统计数据作为参考,使用UCT算法选择每次的子节点;Simulation是简单的模拟,从叶节点开始,用自定义的roll-out policy(可以只是简单的随机概率)来选择子节点,且模拟经过的节点并不加入树中。
  • Backpropagation:根据Simulation的结果,沿着搜索树的路径向上更新节点的统计信息,包括“胜利次数”和“游戏次数”,用于Selection做决策。

在SeqGAN中,实际上只应用了上述的Simulation过程:对于非完整的序列


,以


(等同于Generator)作为roll-out policy,将剩余的T-t个元素模拟出来,这样就可以利用Discriminator进行评估了。为了减小对value估计的误差,会进行N次模拟,对这N个结果取平均值。
最终得到了完整的action-value函数:

 

policy gradient计算

Generator目标函数的梯度可以初步推导为:

 

 

在此基础上,可以去掉期望项,构造一个无偏估计再继续推导:

 


源码对loss的实现为:

  • 111行:x是一个batch生成的所有序列,原来是一个三维数组,这里进行了reshape并转化为one-hot vector,最终得到一个二维数组,每一行以one-hot的形式代表这些生成序列的每一个元素,行数是batch size*sequence length。

  • 113行:最终也是得到一个二维数组,行数与上面相同,每一行代表这些生成序列每个时刻t关于所有候选元素的log概率分布,形如

     

  •  

     

    114行:这里的括号对应110行,运算得到这些序列每个元素被选中的log likelihood,即

  • 116行:这些生成序列每个时刻的reward。

  •  

     

    117行:括号对应于109行的结尾,括号内的运算得到了每个时刻的

     

    ,reduce_sum的意义是,对一个batch中所有序列的所有

    进行总的求和,负号的作用则是把梯度上升问题转化为梯度下降。虽然没有显式地计算期望值,但归因于大量的取样和学习率的存在,最终自动推导出来的梯度是与上述公式相符的。



作者:6e845d5ac37b
链接:https://www.jianshu.com/p/e1b87286bfae
来源:简书
简书著作权归作者所有,任何形式的转载都请联系作者获得授权并注明出处。

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值