你好,我是郭震
这篇从零使用Python,实现生成对抗网络(GAN)的基本版本。
GAN使用两套网络,分别是判别器(D)网络和生成器(G)网络,最重要的是弄清楚每套网络的输入和输出分别是什么,两套网络如何结合在一起,及优化的目标即cost function如何定义。
通俗来讲,两套网络结合的方法,就是G会从D的判分中不断提升生成能力,要知道G最开始的输入全部是噪点,这个思想也是文生图,文生视频的基石。
下面这段代码展示了使用PyTorch框架进行生成对抗网络(GAN)训练的基本流程。
下面这些解释非常重要:
对于判别器网络而言,它的目标是最大化表达式 log(D(x)) + log(1 - D(G(z)))
,其中:
D(x)
是判别器网络对真实图像x
的输出,这个值代表判别器认为图像是真实的概率。D(G(z))
是判别器网络对生成图像G(z)
的输出,这个值代表判别器认为通过生成器从噪声z
生成的图像是真实的概率。log(D(x))
的目标是使得判别器能够尽可能地将真实图像分类为真实(即,使D(x)
接近于1)。log(1 - D(G(z)))
的目标是使得判别器能够将生成的图像分类为假(即,使D(G(z))
接近于0)。
# GAN 训练的基本代码
for epoch in range(num_epochs):
for i, data in enumerate(dataloader, 0):
# 更新判别器网络:maximize log(D(x)) + log(1 - D(G(z)))
# 在真实图像上训练
netD.zero_grad()
real_cpu = data[0].to(device)
batch_size = real_cpu.size(0)
label = torch.full((batch_si