GAN网络详解

算法描述
生成对抗网络(Generative Adversarial Nets)模型中的两位博弈方分别有生成网络(Generator)与判别网络(Discriminator)充当。当生成网络G捕捉到样本数据分布,用服从某一分布的噪声z生成一个类似真实训练数据的样本,与真实样本越接近越好;判别网络D一般是一个二分类模型,在本文中D是一个多分类器,用于估计一个样本来自于真实数据的概率,如果样本来自于真实数据,则D输出大概率,否则输出小概率。本文中,判别网络需要在此基础上实现分类功能。

在训练的过程中,需要固定一方,更新另一方的网络状态,如此交替进行。在整个训练的过程中,双方都极力优化自己的网络,从而形成竞争对抗,知道双方达到一个动态的平衡。此时生成网络训练出来的数据与真实数据的分布几乎相同,判别网络也无法再判断出真伪。
本文中生成对抗网络主要分为两部分,生成网络(Generator)与判别网络(Discriminator)。向生成网络内输入噪声,通过多次反卷积的方式得到一个28x28x1的图像作为X_fake,此时将真实的图像X_real与生成器生成的X_fake放入判别网络,判别网络使用多次卷积与Sigmoid函数并通过交叉熵函数计算出判别网络的损失函数D_loss,通过判别网络的损失函数D_loss计算得到生成网络损失函数G_loss。使用G_loss与D_loss对生成网络与判别网络进行参数调整。
在这里插入图片描述

算法流程
1.输入噪声z
2.通过生成网络G得到X_fake=G(z)
3.从数据集中获取真实数据X_real
4.通过判别网络D计算D(real logits)=D(X_real)
5.通过判别网络D计算D(fake logits)=D(X_fake)
6.使用交叉熵函数做损失函数根据D(real logits)计算D(loss real)
7.使用交叉熵函数做损失函数根据D(fake logits)计算D(loss fake)
8.计算判别网络损失函数D_loss=D(loss real)+ D_(loss fake)
9.使用交叉熵函数做损失函数计算生成网络损失函数G_loss
10.使用D_loss对判别网络进行参数调整,使用G_loss对生成网络参数进行调整

它做的是去最大化D的区分度,最小化G(U-net)和real数据集的数据分布,在最小化损失函数时,可以通过梯度下降法来一步步的迭代求解,得到最小化的损失函数,和模型参数值

在原始 GAN 中,无法控制要生成的内容,因为输出仅依赖于随机噪声。我们可以将条件输入 c 添加到随机噪声 Z,以便生成的图像由 G(c,z) 定义。这就是 CGAN[6],通常条件输入矢量 c 与噪声矢量 z 直接连接即可,并且将得到的矢量原样作为发生器的输入,就像它在原始 GAN 中一样。条件 c 可以是图像的类,对象的属性或嵌入想要生成的图像的文本描述,甚至是图片。

在这里插入图片描述

使用 PyTorch 实现一个简单的 GAN 模型。以绘画创作为例,假设我们要创造如下“名画”(以正弦图形为例):
在这里插入图片描述

生成该“艺术画作”的代码如下:

def artist_works(): # painting from the famous artist (real target)
r = 0.02 * np.random.randn(1, ART_COMPONENTS)
paintings = np.sin(PAINT_POINTS * np.pi) + r
paintings = torch.from_numpy(paintings).float()
return paintings
然后,分别定义 G 网络和 D 网络模型:

G = nn.Sequential( # Generator
nn.Linear(N_IDEAS, 128), # random ideas (could from normal distribution)
nn.ReLU(),
nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideas)

D = nn.Sequential( # Discriminator
nn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like G
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid(), # tell the probability that the art work is made by artist
)
我们设置 Adam 算法进行优化:

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)
最后,构建 GAN 迭代训练过程:

plt.ion() # something about continuous plotting

D_loss_history = []
G_loss_history = []
for step in range(10000):
artist_paintings = artist_works() # real painting from artist
G_ideas = torch.randn(BATCH_SIZE, N_IDEAS) # random ideas G_paintings = G(G_ideas) # fake painting from G (random ideas)

prob_artist0 = D(artist_paintings) # D try to increase this prob
prob_artist1 = D(G_paintings) # D try to reduce this prob

D_loss = - torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
G_loss = torch.mean(torch.log(1. - prob_artist1))

D_loss_history.append(D_loss)
G_loss_history.append(G_loss)

opt_D.zero_grad()
D_loss.backward(retain_graph=True) # reusing computational graph
opt_D.step()

opt_G.zero_grad()
G_loss.backward()
opt_G.step()

if step % 50 == 0: # plotting
plt.cla()
plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c=’#4AD631’, lw=3, label=‘Generated painting’,)
plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c=’#74BCFF’, lw=3, label=‘standard curve’)
plt.text(-1, 0.75, ‘D accuracy=%.2f (0.5 for D to converge)’ % prob_artist0.data.numpy().mean(), fontdict={‘size’: 8}) plt.text(-1, 0.5, ‘D score= %.2f (-1.38 for G to converge)’ % -D_loss.data.numpy(), fontdict={‘size’: 8})
plt.ylim((-1, 1));plt.legend(loc=‘lower right’, fontsize=10);plt.draw();plt.pause(0.01)

plt.ioff()
plt.show()

采用动态绘图的方式,便于时刻观察 GAN 模型训练情况。
迭代次数为 1 时:在这里插入图片描述

迭代次数为 200 时:
在这里插入图片描述
迭代次数为 1000 时:在这里插入图片描述

迭代次数为 10000 时:在这里插入图片描述

经过 10000 次迭代训练之后,生成的曲线已经与标准曲线非常接近了。D 的 score 也如预期接近 0.5。

  • 5
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

perfect Yang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值