WGAN全称Wasserstein GAN,重点以及和DCGAN的不同之处在于Wasserstein,Wasserstein是个啥呢?网上有很多很棒的解答,这里直接送上传送门KL散度、JS散度、Wasserstein距离
WGAN的作者选择Wasserstein距离来度量真实图像分布和生成图像分布之间的距离,目标即为最小化该距离。尽管Wasserstein距离从公式的形式上来看比较复杂难懂,但是结合代码实现来看其实非常简单。
官方GitHub给出的代码如下:
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
loss_G = -torch.mean(discriminator(gen_imgs))
我们再摆上Wasserstein距离公式:
P
1
P_1
P1表示输入为真实图像的判别器的输出分布,
P
2
P_2
P2表示输入为生成图像的判别器的输出分布,二者合起来即为联合概率分布,那么
γ
\gamma
γ满足这样的联合概率分布,而(
x
x
x,
y
y
y)是随机变量
γ
\gamma
γ所能取到的值的集合,即为一个批次训练中每个input-target pairs得到的输入为真实图像的判别器的输出与输入为生成图像的判别器的输出。之后求两个输出的欧氏距离,再求期望,即求一个批次训练中所有计算得到的欧氏距离的均值再求下界即最小化该平均距离。这大致即为Wasserstein距离在WGAN中所要做的事,所以通过代码实现我们可以发现WGAN的损失其实就是DCGAN的损失去掉log,但是原理却截然不同。而且判别器的输出并没有像DCGAN一样使用sigmoid激活函数,毕竟WGAN的损失并不是基于概率的度量而是距离的度量,所以判别器最后的输出层只是一个简单的线性层。
下面看看伪代码都做了什么:
WGAN的训练方式与DCGAN完全不同,loss采用Wasserstein距离来度量,并且优化算法采用的是RMSProp这里不讲解该方法有兴趣的可以看一下吴恩达的视频,权重得到更新以后需要进行数值clip,将数值固定在[-c,c]区间之内,然后每
n
c
r
i
t
i
c
n_{critic}
ncritic次训练一次生成器,同样优化算法采用的是RMSProp。