理解 cGANs with Projection Discriminator

Conditional GAN 自从出现以来,产生了很多种 discriminator。cGANs with Projection Discriminator 介绍了一种新的 conditional GAN 判别器设计方法。从实践中来看,它还是比较好用的,被用在了 Spectral Normalization for Generative Adversarial Networks 和 Self-Attention Generative Adversarial Networks 中。

本文将详细介绍该算法的推导过程以及本人的理解,指出它其实是化简版本的 AC-GANs。文中代码来自:crcrpar/pytorch.sngan_projection,该代码为 Spectral Normalization GAN 的 pytorch 版本。

背景

条件GAN的控制器
(a) 中,label 和 data 在输入端拼接起来,一起输入一个神经网络,输出 data 为真实数据的概率。
(b) 和 (a) 很像,只不过是把 label 和 data 在神经网络的中间层拼接起来,输出 data 属于真实数据的概率。
c 图显示的 discriminator 有两部分输出,一部分和传统的 GAN 一样,表示输入数据真实的概率,用 adversarial loss 指导;另一部分为 classification output,表示数据属于各类的概率,用 classification loss 指导。
(d) 图为本文提出的结构,输入图片首先经过网络 ϕ \phi ϕ提取特征,然后把特征分成两路:一路与经过编码的类别标签 y 做点乘,另一路再通过网络 ψ \psi ψ映射成一维向量。最后两路相加,作为神经网络最终的输出。注意这个输出类似于 W-GAN,不经过 sigmoid 函数映射,越大代表越真实。

推导

首先回顾GAN中D对抗损失函数
base function
把 (1) 对 D(x) 求导,令导数为0,得到最优 discriminator:
在这里插入图片描述
即最优的 discriminator 就是数据为 real 的概率比上数据为 real 与数据为 fake 的概率之和。

接下来我们看 conditional discriminator 的损失函数:
在这里插入图片描述
根据 (1) 与 (2) 的推导,可知 conditional GAN 的最优 discriminator 应为:
在这里插入图片描述
而传统的 GAN 输出一个代表数据真实程度的概率, D ( x , y ; θ ) = A ( f ( x , y ; θ ) ) D(x,y;\theta ) = {\rm A}(f(x,y;\theta )) D(x,y;θ)=A(f(x,y;θ))。其中 A \rm A A为 activation function,通常为 sigmoid 函数:f(x)=1/(1+exp(-x))。因此,公式 (4) 可化为:
在这里插入图片描述
接下来思考公式(6)说了一件什么事
首先看公式(6)倒数第二行第二项 log ⁡ [ q d a t a ( x ) / p g ( x ) ] \log [{q_{data}}(x)/{p_g}(x)] log[qdata(x)/pg(x)] 分子为 x 属于真实数据的概率,分母为 x 属于虚假数据的概率。对于一个最优分类器,当输入的 x 为真实数据时,我们希望它的输出值越大越好,而当输入 x 为虚假数据时,我们希望它的输出值越小越好,因此,这一项相当于在判断 x 的真实性
接着看公式 (6) 倒数第二行第一项 log ⁡ [ q d a t a ( y ∣ x ) / p g ( y ∣ x ) ] \log [{q_{data}}(y|x)/{p_g}(y|x)] log[qdata(yx)/pg(yx)] 分子为假如 x 真实,那么 x 属于类别 y 的概率,而分母为假如 x 虚假,x 属于类别 y 的概率。显然,这项在判断 x 的类别是不是我们想要的,这项越大,代表 x 真实时,属于类别 y 的概率越大。

接下来继续推导:

图二

研究公式 (6) 的 log ⁡ [ q d a t a ( y ∣ x ) / p g ( y ∣ x ) ] \log [{q_{data}}(y|x)/{p_g}(y|x)] log[qdata(yx)/pg(yx)] 中的条件概率。对于一个 C 维输出的分类器,我们通常用 softmax 函数计算 x 属于各个类别的概率:
在这里插入图片描述
(7)中的 o j o_j oj为神经网络全连接层的输出,我们可以把它分解成倒数第二层的 feature 乘以一个 C 行的矩阵:
在这里插入图片描述
(8) 中的 ϕ ( x ) \phi(x) ϕ(x) 即为图 2 中的 ϕ \phi ϕ ,把 (8) 带入 (7) 可得:
在这里插入图片描述因此
在这里插入图片描述

v c q d a t a − v c p g = v c v_c^{q_{data}}-v_c^{p_g}=v_c vcqdatavcpg=vc

− ( log ⁡ Z q d a t a ( ϕ ( x ) ) − log ⁡ Z p g ( ϕ ( x ) ) ) + log ⁡ q d a t a ( x ) p g ( x ) = ψ ( ϕ ( x ) ) -\big(\log Z^{q_{data}}(\phi(x))-\log Z^{p_g}(\phi(x))\big)+\log\frac{q_{data}(x)}{p_g(x)}=\psi(\phi(x)) (logZqdata(ϕ(x))logZpg(ϕ(x)))+logpg(x)qdata(x)=ψ(ϕ(x))

得:

f ∗ ( x , y = c ) = v c T ϕ ( x ) + ψ ( ϕ ( x ) ) f^*(x,y=c)=v_c^T\phi(x)+\psi(\phi(x)) f(x,y=c)=vcTϕ(x)+ψ(ϕ(x))

令矩阵 V 中各行向量为 v j T v_j^T vjT , y 为 one-hot label。则最终有:

(12) f ∗ ( x , y ) = y T V ϕ ( x ) + ψ ( ϕ ( x ) ) f^*(x,y)=y^TV\phi(x)+\psi(\phi(x))\tag{12} f(x,y)=yTVϕ(x)+ψ(ϕ(x))(12)

理解

可以把 V 理解成 label 的 embedding 层。这正是图 2 中的网络结构。
如果跳出繁琐的推导过程,直接看公式 (12),我们发现这个最优分类器包含两部分: y T V ϕ ( x ) 和 ψ ( ϕ ( x ) ) y^TV\phi(x) 和 \psi(\phi(x)) yTVϕ(x)ψ(ϕ(x))

对于 ψ ( ϕ ( x ) ) \psi(\phi(x)) ψ(ϕ(x)) ,其实就起 vanilla GAN discriminator 的作用,用于判断数据 x 是否为真实数据

y T V ϕ ( x ) = ( V ϕ ( x ) ) T y y^TV\phi(x)=(V\phi(x))^Ty yTVϕ(x)=(Vϕ(x))Ty ,相当于神经网络的输出 V ϕ ( x ) V\phi(x) Vϕ(x) 与 one-hot label y 的点乘,从而取出来输出部分对应的 target 类的值,这项越大,代表越逼真。我们回顾 multi-class crossentropy 的公式,如果某个数据的 one-hot label 为 y ,网络的输出的概率分布为 p,则对该条数据的损失函数为:

(13) L ( p , y ) = ∑ j = 1 c y j log ⁡ p j \mathcal{L}(p, y)=\sum_{j=1}^c y_j \log p_j\tag{13} L(p,y)=j=1cyjlogpj(13)

注意 (13) 中的 y j y_j yj 大部分为0,只有 ground truth 是 1,因此,对于 multi-class crossentropy,相当于把神经网络输出的概率分布中,对应 ground truth 类的概率提了出来,求了个对数。

y T V ϕ ( x ) = ( V ϕ ( x ) ) T y y^TV\phi(x)=(V\phi(x))^Ty yTVϕ(x)=(Vϕ(x))Ty 中的 ( V ϕ ( x ) ) (V\phi(x)) (Vϕ(x)) 相当于一个特殊的分类网络,它输出的数字没有经过 softmax 映射成概率分布,但仍然可以代表输入数据属于某个类的程度深浅,越大代表越属于某类,越小则反之。

  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值