1 Theoretical Formulation
$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$
For the super-resolution problem, this formula specializes to:
$$\min_{\theta_G} \max_{\theta_D} \mathbb{E}_{I^{HR} \sim p_{train}(I^{HR})}[\log D_{\theta_D}(I^{HR})] + \mathbb{E}_{I^{LR} \sim p_G(I^{LR})}[\log(1 - D_{\theta_D}(G_{\theta_G}(I^{LR})))]$$
Here I interpret $D_{\theta_D}(I^{HR})$ as the distance of the ground-truth (GT) image from the fake label $0$, and $(1 - D_{\theta_D}(G_{\theta_G}(I^{LR})))$ as the distance of the super-resolved (SR) image from the real label $1$. The discriminator is optimized to make the sum of these two distances as large as possible: the larger the sum, the more accurate its judgments. The generator's optimization objective is the opposite of the discriminator's.
2 Training Procedure
A GAN poses a $\min$-$\max$ problem, but in practice the optimization is split into an outer and an inner loop, where the outer loop is $\min_G$ and the inner loop is $\max_D$. Since deep learning models are trained with stochastic gradient descent, the inner $\max_D$ problem is rewritten as a $\min_D$ problem so that both networks can be trained in the same way. Concretely:
$$\begin{aligned} \min_G V(G) &= \min_{\theta_G} \mathbb{E}_{I^{HR} \sim p_{train}(I^{HR})}[\log D_{\theta_D}(I^{HR})] + \mathbb{E}_{I^{LR} \sim p_G(I^{LR})}[\log(1 - D_{\theta_D}(G_{\theta_G}(I^{LR})))] \\ \min_D V(D) &= \min_{\theta_D} \mathbb{E}_{I^{HR} \sim p_{train}(I^{HR})}[\log(1 - D_{\theta_D}(I^{HR}))] + \mathbb{E}_{I^{LR} \sim p_G(I^{LR})}[\log D_{\theta_D}(G_{\theta_G}(I^{LR}))] \end{aligned}$$
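In code, this two-level problem becomes one discriminator update followed by one generator update per batch. Below is a minimal runnable sketch of that alternating scheme; the tiny networks and random tensors are stand-ins I made up for illustration, not the SRGAN architecture, and the label-based BCE loss anticipates Approach 2 below.

```python
import torch
import torch.nn as nn

# Toy stand-ins (assumptions for illustration only): a 4x upscaling
# generator and a per-pixel discriminator ending in a sigmoid.
netG = nn.Sequential(nn.Conv2d(3, 48, 3, padding=1), nn.PixelShuffle(4))
netD = nn.Sequential(nn.Conv2d(3, 1, 3, padding=1), nn.Sigmoid())
adv_loss = nn.BCELoss()
opt_g = torch.optim.Adam(netG.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(netD.parameters(), lr=1e-4)

for _ in range(2):  # stands in for iterating over (LR, HR) pairs
    lr_img = torch.rand(1, 3, 24, 24)
    hr_img = torch.rand(1, 3, 96, 96)

    # Inner problem: one SGD step on min_D V(D).
    fake_img = netG(lr_img)
    real_out = netD(hr_img)
    fake_out = netD(fake_img.detach())  # block gradients into netG
    d_loss = (adv_loss(real_out, torch.ones_like(real_out))
              + adv_loss(fake_out, torch.zeros_like(fake_out)))
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Outer problem: one SGD step on min_G V(G); the SR output is
    # labeled as real so that netG is pushed to fool netD.
    fake_out = netD(fake_img)
    g_loss = adv_loss(fake_out, torch.ones_like(fake_out))
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
```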
3 Implementation
In neural-network design, the GAN mechanism shows up mainly in the design of the loss functions. The snippets below are only meant to convey the idea and will not run on their own: the networks (netD / net_d), the image tensors, and other_loss (the non-adversarial part of the generator loss, e.g. a pixel or perceptual term) are assumed to be defined elsewhere.
3.1 Approach 1
```python
# Generator step: netD is assumed to end in a sigmoid, so fake_out is
# a "realness" score in (0, 1); minimizing (1 - fake_out) pushes the
# SR image toward the real label 1.
fake_out = netD(fake_img).mean()
g_loss = other_loss + (1 - fake_out)

# Discriminator step: detach fake_img so no gradient reaches netG.
real_out = netD(real_img).mean()
fake_out = netD(fake_img.detach()).mean()
d_loss = 1 - real_out + fake_out
```
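Note that d_loss = 1 - real_out + fake_out equals $2$ minus the sum of the two "distances" from Section 1, so minimizing it is exactly maximizing that sum: real_out is pushed toward $1$ and fake_out toward $0$. This style scores the sigmoid outputs directly instead of passing logits through a loss function.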
3.2 Approach 2
```python
import torch
import torch.nn as nn

# BCEWithLogitsLoss applies the sigmoid internally, so net_d is
# assumed to output raw logits (no sigmoid at the end).
adv_loss = nn.BCEWithLogitsLoss()

# Generator step: label the SR image as real (1) to fool net_d.
fake_out = net_d(fake_img)
real = torch.ones_like(fake_out)
g_loss = other_loss + adv_loss(fake_out, real)

# Discriminator step: real images -> label 1, detached SR images -> 0.
fake_out = net_d(fake_img.detach())
real_out = net_d(real_img)
fake = torch.zeros_like(real_out)
fake_loss = adv_loss(fake_out, fake)
real_loss = adv_loss(real_out, real)
d_loss = (fake_loss + real_loss) / 2
```
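Here adv_loss(real_out, real) and adv_loss(fake_out, fake) evaluate to $-\log D_{\theta_D}(I^{HR})$ and $-\log(1 - D_{\theta_D}(G_{\theta_G}(I^{LR})))$, so minimizing d_loss plays the role of the $\max_D$ step, and the factor $1/2$ only rescales the gradient. The generator term adv_loss(fake_out, real) is $-\log D_{\theta_D}(G_{\theta_G}(I^{LR}))$, the commonly used non-saturating replacement for the $\log(1 - D)$ term, which provides stronger gradients when the discriminator is confident.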