Policy gradient 解决Sequence generation使用GAN时梯度无法更新的问题

GAN

z ∈ R T × d z\in \mathbb{R}^{T\times d} zRT×d
o u t p u t g = G θ ( z ) ∈ R T × V ,  where V is the size of vocabulary. output_g=G_{\theta}(z)\in\mathbb{R}^{T\times V},\ \text{where V is the size of vocabulary.} outputg=Gθ(z)RT×V, where V is the size of vocabulary.
Y = arg max ⁡ ( o u t p u t g ) ∈ R T × 1 → this   operation   will   prevent   generator   from   being   updated. Y=\argmax(output_{g})\in \mathbb{R}^{T\times1}\rightarrow\textbf{\text{this operation will prevent generator from being updated.}} Y=argmax(outputg)RT×1this operation will prevent generator from being updated.
mapping Y to dense vector as Discriminator input by using embedding: \text{mapping Y to dense vector as Discriminator input by using embedding:} mapping Y to dense vector as Discriminator input by using embedding:
i n p u t d = E m b e d d i n g ( Y ) ∈ R T × d m input_d=Embedding(Y)\in\mathbb{R}^{T\times d_m} inputd=Embedding(Y)RT×dm
o u t p u t d = D θ ( i n p u t d ) ∈ R output_d=D_{\theta}(input_d)\in\mathbb{R} outputd=Dθ(inputd)R
L = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ]   ( 1 ) \mathcal{L}=\mathbb{E}_{x\sim p_{data}(x)}[\log D(x)]+\mathbb{E}_{z\sim p_z(z)}[\log(1-D(G(z)))]\ (1) L=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))] (1)

GAN在计算机视觉特别是图像生成方面应用很广泛,但是在文本生成领域应用有比较大的困难。主要是因为GAN是用于生成真实连续的数据(例如图像), 而文本生成是生成离散的数据(对应于词典中的字符)。具体的说,在文本生成中,generator中有argmax操作,该操作是不可导的,在反向传播时,梯度更新会在该操作处停止,从而使Generator无法更新。

Policy Gradient

从以上分析可以得出,Generator无法更新主要是存在不可求导操作( arg max ⁡ \argmax argmax)引起的。解决这个问题可以从两个方面思考,一个是用一个可导函数(神经网络)逼近 arg max ⁡ \argmax argmax操作(这个是我自己猜想的,并没有找到参考文献,不一定可行),另一个是在更新Generator时不使用该操作。policy gradient就是使用的第二种方法,这里的policy指的就是Generator。
Policy gradient是强化学习(Reinforcement learning, RL)中的一种说法,具体的可以参考网上强化学习的资料。RL主要有几个要素agent、environment、action、reward、state。将RL运用到文本生成领域,可以将agent看成generator, environment看成Discriminator,action为将要生成的字符(token),reward为Discriminator给出的打分(生成的句子被判断成为真实句子的概率),state为已经生成的token。
Discriminator的更新与上述方法并无区别,Generator更新主要区别在于loss并非从discriminator端传入(式1),而是将 V ( s 0 ) V(s_0) V(s0)(状态 s 0 s_0 s0的价值函数)作为loss,目标则是最大化价值函数。具体式子如下:
J ( θ ) = E [ R T ∣ s 0 , θ ] = ∑ y 1 ∈ Y G θ ( y 1 ∣ s 0 ) Q G θ D ϕ ( s 0 , y 1 ) J(\theta)=\mathbb{E}[R_T|s_0,\theta]=\sum_{y_1\in\mathcal{Y}}G_{\theta}(y_1|s_0)Q_{G_{\theta}}^{D_{\phi}}(s_0,y_1) J(θ)=E[RTs0,θ]=y1YGθ(y1s0)QGθDϕ(s0,y1)
该式出自论文seqGAN。按照我的理解, s 0 s_0 s0为sequence开始标识,是一个特殊字符,例如<BOS>。 y 1 y_1 y1为即将生成的下一个token。 Y \mathcal{Y} Y为字典。 G θ ( y 1 ∣ s 0 ) G_{\theta}(y_1|s_0) Gθ(y1s0)表示当前状态为 s 0 s_0 s0,在策略 G θ G_{\theta} Gθ下,下一个action为 y 1 y_1 y1的概率。 Q G θ D ϕ ( s 0 , y 1 ) Q_{G_{\theta}}^{D_{\phi}}(s_0,y_1) QGθDϕ(s0,y1)为action-value。文章采用Monte Carlo采样方法计算该值:
Q G θ D ϕ ( s = Y 1 : t − 1 , a = y t ) = 1 N ∑ n = 1 N D ϕ ( Y 1 : T n ) , Y 1 : T n ∈ M C G β ( Y 1 : T ; N ) Q_{G_{\theta}}^{D_{\phi}}(s=Y_{1:t-1},a=y_t)=\dfrac{1}{N}\sum_{n=1}^ND_{\phi}(Y_{1:T}^n), Y_{1:T}^n \in MC^{G_{\beta}}(Y_{1:T};N) QGθDϕ(s=Y1:t1,a=yt)=N1n=1NDϕ(Y1:Tn),Y1:TnMCGβ(Y1:T;N)
上式计算的是state为 Y 1 : t − 1 Y_{1:t-1} Y1:t1时,action为 y t y_t yt的return。这样做的好处在于,不仅可以获得整个序列的reward,还可以获得中间任意位置生成token的reward。也就是说,不但考虑了长期受益,也考虑了短期收益。
我们说过policy gradient可以有效的解决梯度更新无法传递到generator上的问题,观察上式,跟新generator的目标函数 J ( θ ) J({\theta}) J(θ)并没有采用任何不可求导操作,整个式子可以看成是state-action pair ( s , y ) (s,y) (s,y)能获得的Reward的期望( G θ G_{\theta} Gθ返回的是概率值)。
上述讲的很粗糙,要真正的理解其中细节,需要仔细的看看论文,这篇论文花了我很长时间。先是恶补了一下强化学习的相关资料,然后再啃的论文,现在也不敢说百分百读懂了。欢迎一起交流!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值