文章目录
【生成对抗网络GAN】原理及实现
1. GAN的实现原理
GAN的基本架构如下图所示。
图源
GAN的核心是生成器Generator和判别器Discriminator。二者本质上都是多层感知机网络。
- Generator:负责根据随机信号产生数据(无中生有)
- Discriminator:负责判定Generator生成数据的真伪(火眼金睛)
GAN训练的基本流程:每一轮梯度反向传播过程中,先训练Discriminator,再训练Generator。
具体来说,假设现在进行第k轮训练:
- 先训练Discriminator:先固定Gennerator,即Gennerator的参数此时不更新。将真图像和上一轮产生的假图像 G k − 1 ( z ) G^{k-1}(z) Gk−1(z)拼接在一起,分别打上标签1和0。拼接的图像x输入Discriminator进行打分,得到一个score。根据score和标签的损失函数loss就可以梯度反向传播,更新Discriminator的参数。(相当于训练一个二分类神经网络D)
- 后训练Generator:先固定Discriminator:
discriminator.trainable = False
,即Discriminator的参数此时不能更新。Generator根据输入随机信号z产生假图像G^{k-1}(z),输入Discriminator进行打分score。score和标签1之间的差值作为损失函数loss反向传播,更新Generator的参数。
2. GAN的数学原理(简单了解的同学可以不看这里)
变量命名:
- p d a t a p_{data} pdata:产生数据的概率分布(真图像)
- p g p_{g} pg:随机信号的先验概率(假图像)
2.1. 训练Discriminator的数学原理
2.1.1. 最优解
D
G
∗
(
x
)
D^*_G(x)
DG∗(x)
用
D
G
(
x
)
D_G(x)
DG(x)表示假图像与真图像之间的相似性,即上文的score。
D
G
(
x
)
D_G(x)
DG(x)越大,则Generator以假乱真的能力越强。
固定Generator时,Discriminator的最优解(极大值点)为:
2.1.2. 优化问题
max
V
(
G
,
D
)
\max V(G,D)
maxV(G,D)
Discriminator的优化目标是增强判别真假的能力,因此可以归纳为一个优化问题,即
max
D
V
(
G
,
D
)
\max_D V(G,D)
maxDV(G,D)。因为Generator已经固定,有
x
=
g
(
z
)
x=g(z)
x=g(z),
x
x
x与
z
z
z一一映射。因此,第二项可以用
x
x
x替换
z
z
z。
很明显,
V
(
G
,
D
)
V(G,D)
V(G,D)是一个香农熵(Jesen-Shannon Divergence,JSD)的形式,是为了衡量两种概率分布(这里是,
p
d
a
t
a
p_{data}
pdata和
p
g
p_g
pg)的相似性提出的方法。
D
J
S
(
P
∥
Q
)
=
1
2
D
K
L
(
P
∥
M
)
+
1
2
D
K
L
(
Q
∥
M
)
\operatorname{D_{\mathrm{JS}}}(P \| Q)=\frac{1}{2} D_{\mathrm{KL}}(P \| M)+\frac{1}{2} D_{\mathrm{KL}}(Q \| M)
DJS(P∥Q)=21DKL(P∥M)+21DKL(Q∥M)
其中,
M
=
P
+
Q
2
M=\frac{P+Q}{2}
M=2P+Q。
D
K
L
(
∗
)
D_{\mathrm{KL}}(*)
DKL(∗)表示相对熵(Kullback-Leibler Divergence,KLD)。
D
K
L
(
P
∥
Q
)
=
∫
x
P
(
x
)
ln
P
(
x
)
Q
(
x
)
D_{\mathrm{KL}}(P \| Q)=\int_x P(x) \ln \frac{P(x)}{Q(x)}
DKL(P∥Q)=∫xP(x)lnQ(x)P(x)
我们要找到
V
(
G
,
D
)
V(G,D)
V(G,D)的极大值,因此对式(3)积分号内的数学表达式关于
D
(
x
)
D(x)
D(x)求导,导数为0处即为极大值点:
p
d
a
t
a
(
x
)
1
ln
10
1
D
(
x
)
−
p
g
1
ln
10
1
1
−
D
(
x
)
=
0
p_{data}(x)\frac{1}{\ln10}\frac{1}{D(x)}-p_g\frac{1}{\ln10}\frac{1}{1-D(x)}=0
pdata(x)ln101D(x)1−pgln1011−D(x)1=0
进而有
D
G
∗
(
x
)
=
p
data
(
x
)
p
data
(
x
)
+
p
g
(
x
)
∈
[
0
,
1
]
D_{G}^{*}(\boldsymbol{x})=\frac{p_{\text {data }}(\boldsymbol{x})}{p_{\text {data }}(\boldsymbol{x})+p_{g}(\boldsymbol{x})}\in[0,1]
DG∗(x)=pdata (x)+pg(x)pdata (x)∈[0,1]
易知,
p
g
=
p
d
a
t
a
p_g=p_{data}
pg=pdata时,
D
G
∗
D^*_G
DG∗的值最大。
D
G
,
m
a
x
∗
=
1
2
D^*_{G,max}=\frac{1}{2}
DG,max∗=21
2.2. 训练Generator的数学原理
Discriminator固定时, 令损失函数 C ( G ) = V ( G , D G ∗ ) = max D V ( G , D ) C(G)=V(G, D^*_G)=\max _{D} V(G,D) C(G)=V(G,DG∗)=maxDV(G,D),优化目标转变成使生成假图像尽可能接近真实图像。因此,形成了一个新的优化问题: min G C ( G ) = min G max D V ( G , D ) \min_G C(G)=\min_G \max _{D} V(G, D) minGC(G)=minGmaxDV(G,D)。
那什么时候
C
(
G
)
C(G)
C(G)最小呢?
应该是
p
g
p_g
pg和
p
d
a
t
a
p_{data}
pdata最接近的时候,即生成的假图像最接近真实图像。理想的情况就是和真图像一摸一样。
那么,用香农熵考察
p
g
p_g
pg和
p
d
a
t
a
p_{data}
pdata的相似性:
D
J
S
(
p
d
a
t
a
∥
p
g
)
=
1
2
D
K
L
(
p
d
a
t
a
∥
p
d
a
t
a
+
p
g
2
)
+
1
2
D
K
L
(
p
g
∥
p
d
a
t
a
+
p
g
2
)
=
1
2
(
log
2
+
∫
x
p
d
a
t
a
(
x
)
log
p
d
a
t
a
(
x
)
p
d
a
t
a
+
p
g
(
x
)
d
x
)
+
1
2
(
log
2
+
∫
x
p
g
(
x
)
log
p
g
(
x
)
p
d
a
t
a
+
p
g
(
x
)
d
x
)
=
1
2
(
log
4
+
V
(
G
,
D
G
∗
)
)
=
1
2
(
log
4
+
C
(
G
)
)
\begin{aligned} D_{J S}\left(p_{data} \| p_{g}\right)=& \frac{1}{2} D_{K L}\left(p_{data} \| \frac{p_{data}+p_{g}}{2}\right)+\frac{1}{2} D_{K L}\left(p_{g} \| \frac{p_{data}+p_{g}}{2}\right) \\ =& \frac{1}{2}\left(\log 2+\int_{x} p_{data}(x) \log \frac{p_{data}(x)}{p_{data}+p_{g}(x)} d x\right)+\\ & \frac{1}{2}\left(\log 2+\int_{x} p_{g}(x) \log \frac{p_{g}(x)}{p_{data}+p_{g}(x)} d x\right) \\ =& \frac{1}{2}\left(\log 4+V\left(G, D^{*}_G\right)\right)\\ =& \frac{1}{2}\left(\log 4+C\left(G\right)\right) \end{aligned}
DJS(pdata∥pg)====21DKL(pdata∥2pdata+pg)+21DKL(pg∥2pdata+pg)21(log2+∫xpdata(x)logpdata+pg(x)pdata(x)dx)+21(log2+∫xpg(x)logpdata+pg(x)pg(x)dx)21(log4+V(G,DG∗))21(log4+C(G))
进而,
C
(
G
)
C(G)
C(G)可以表示为
C
(
G
)
=
−
log
(
4
)
+
2
⋅
D
J
S
(
p
data
∥
p
g
)
C(G)=-\log (4)+2 \cdot D_{JS}\left(p_{\text {data }} \| p_{g}\right)
C(G)=−log(4)+2⋅DJS(pdata ∥pg)
因为香农熵
D
J
S
(
∗
)
D_{JS}(*)
DJS(∗)非负,且在
p
g
=
p
d
a
t
a
p_g=p_{data}
pg=pdata时取到
D
J
S
(
p
data
∥
p
g
)
=
0
D_{JS}\left(p_{\text {data }} \| p_{g}\right)=0
DJS(pdata ∥pg)=0,有损失函数最小值
C
∗
(
G
)
=
−
log
4
C^*(G)=-\log4
C∗(G)=−log4。
3. GAN的算法实现(Python)
我将实现的GAN开源在Colab上:Code
4. ArcaneGAN虚拟人脸生成
介绍一个GitHub上好玩的GAN项目《双城之战》风格人脸生成ArcaneGAN