GAN、CGAN、DCGN网络及其应用:Data-Free Learning of Student Networks
前言
论文《Generative Adversarial Networks》及github源代码
论文《Conditional Generative Adversarial Networks》及github源代码
论文《Deep Convolutional Generative Adversarial Networks》
一、GAN、CGAN、DCGAN核心思想及算法
一:GAN网络
1、GAN网络: D(x)输出一个0-1范围内的实数值,用来判断图片是真实图片的概率(x是一张真实图片的矩阵)。G(z)输出生成的一张图片的矩阵(z是一个噪声样本矩阵)讲解与代码
2、训练判别器D()时: 下面公式最大化D时,左右项都有D(),便得到下面算法图中第一个公式(要最小化损失函数,便使得D(x)越接近真实图片标签1,D(G(z))越接近生成的假图片标签0。训练时需要G(z).detach(),因为训练D时G不用更新)。目的是使判别器能判断出真实图片和生成的图片。
3、训练生成器G()时: 下面公式最小化G时,只有右项都有G(),便得到下面算法图中第二个公式(要最小化损失函数,便使得D(G(z))越接近1)。目的是使生成器生成的图片尽可能像真实图片。
4、GAN网络的算法图: 训练判别器和生成器是交替迭代的过程【训练k次判别器(只更新D的参数),再训练一次生成器(只更新G的参数),常取k=1】(CGAN网络的算法类似)。
主要代码如下:
"""准则1:在调用backward()时,只有在前向传播过程中各tensor始终为requires_grad时梯度才会被计算。
各tensor的梯度值被计算完后,只有叶子节点的梯度值能够被保留下来,非叶子节点的被自动清除(.grad为None)"""
"""准则2:优化器只优化(更新)给定模型的参数,且只更新叶子节点的data属性值。如在
d_opt=torch.optim.Adam(D.parameters())后d_opt.step()只会优化D.parameters()的叶子节点的data属性值"""
"""对判别器"""
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
d_loss_real = criterion(real_out, real_label) # 得到真实图片的loss
fake_img = G(z).detach() # 随机噪声放入生成网络中,生成一张假的图片。避免梯度传到G,因为G不用更新
fake_out = D(fake_img) # 判别器判断假的图片,
d_loss_fake = criterion(fake_out, fake_label) # 得到假的图片的loss
d_loss = d_loss_real + d_loss_fake # 损失包括判真损失和判假损失
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
"""对生成器"""
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
fake_img = G(z) # 随机噪声输入到生成器中,得到一副假的图片
output = D(fake_img) # 经过判别器得到的结果
output = output.squeeze()
g_loss = criterion(output, real_label) # 得到的假的图片与真实的图片的label的loss
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
#虽然是优化G(z),但D(fake_img)没有进行detach(),因为要保证前向传播过程中各tensor始终为requires_grad
#此时确实只优化了G而没有优化D,因为torch.optim.Adam(G.parameters(), lr=0.0003)、g_optimizer.step()
二:CGAN网络
1、相当于在原始GAN的基础上加上一个条件:condition,以此来指导G的生成过程。
2、CGAN网络的示意图:其中y就是条件,跟数据x和噪声z同时分别输入进D和G网络中。
三:DCGAN网络
DCGAN网络:在GAN的基础上增加深度卷积网络结构,将生成器和判别器的线性层分别换成转置卷积层和卷积层。