本文默认读者对GAN有基本的了解,对以下公式:
(1)
了然于胸,其中D代表Discriminator,G代表Generator,表示真实数据的密度函数,一般为噪声的密度函数。GAN在模拟连续变量的分布中表现得不错,但无法直接应用于离散变量,因为Generator往往最终通过softmax函数输出一个关于所有离散点的概率向量,无法生成one-hot形式输出,足够好的D可以轻易的区分出合成数据和真实数据。而如果加入one-hot(argmax(*))这种函数,将导致不可导,使得G无法被训练,另外argmax函数并没有真实的模拟多项分布。
为解决以上问题,大神们提出了很多不同的方案,比如在《GANS for Sequences of discrete Element with the Gumbel-softmax Distribution》一文中,先阐述了使用Gumbel Max可以代替依照概率采样过程,而保留了带优化参数(重参数法),然后为解决不可导问题,将Gumbel Max 替代为 Gumbel Softmax ,引入退火策略,模拟Gumbel Max的效果。
而在《Sequence Generative Adversarial Nets with Policy Gradient》一文中,作者并没有直接的在公式(1)上优化G ,即先完整的生成文本序列X,再将序列送入G。而且是采用了强化学习的思想,结合D,计算每个输出文字的action reward,并使其得模型的回报的数学期望最大。这是什么意思?为什么这样做就能避免求导问题了呢?别急,先看完怎么做,再看为什么。以下,表示G按顺序输出的i到j的文字序列,γ表示全体文字集合。表示在当前状态S下,采取行动a(下一个输出文字)的回报,其实这里的上角标应该是,表示我们在用蒙特卡洛搜索时所采用的policy,但我们一般默认policy就是当前的。那么第t个输出的回报期望为
关于参数求导得
(2)
中应该也有参数,为什么这里忽略了,原文中说在提供的资料里有更多的推导过程,在此暂不深究。同时注意到项在求导时也应该被看做常数)
进而我们给出式(2)的无偏估计
(3)
其中 定义为,即初始状态。
给出以上公式后,看明白的朋友会发现,(3)中并没有明确给出E[·]的采样方法。文中只是简要的说到” the expectation E[·] can be approximated by sampling methods”,别急
紧接着原文给出了大致的算法步骤,我们重点关注对G的训练的部分,可以看到训练过程大致分为两步:
1. Generate a sequence Y1:T = (y1,...,yT ) ∼ Gθ
2. Update generator parameters via policy gradient
第一步通过当前生成一个序列(猜测这里采用了一些不可导的采样方法,比如Gumbel-max,或者直接依概率随机选择)。第二步,通过公式(3)计算导数,然后使用如adam 等方法优化。如此一来就看清楚了,为什么SeqNet可以解决离散GAN的求导问题,可以感性的理解为,传统的GAN,将理解为一个关于参数θ函数,再送入D,反向求导训练,自然会遇到采样函数不可导的问题。而SeqNet把先固化为常数(表现为的近似),再通过强化学习理论构造可微的待优化函数,进行求导训练。