生成式对抗网络(Generative Adversarial Networks, GAN),简称GAN网络。有人说这是21世纪最让人激动的“发明”,虽然我忘了我是从哪看到的这句话,貌似是发明了卷积神经网络那位大佬说的。我试过以后,对于AI兴趣爱好者来说
确实挺激动的!
对于标题中的cGAN/cDCGAN,小c,全称是conditional,条件的。DC,全称是Deep Convolution,深度卷积。都是GAN网络的一个变种。对于DCGAN与GAN的关系,也很简单,因为最开始GAN网络是用神经网络设计的,而后来出现了计算能力更强的卷积(CNN),训练逻辑相同,只是计算操作不同,当然可以相互替换。
对于原理,网传:一个生成器(Generator),一个判别器(Discriminator),他两相互博弈,相爱相杀,最后产生一个好的结果。。。
What?还要动手吗?
对于此种高端解释,我等菜鸡无法领会,我只想知道网络是怎么训练的?两个部分的输入输出分别是什么?网络如何搭建?Loss如何设计?有了这些,你的程序就可以跑了
还是从代码中理解啥是相爱相杀吧。
先放一张整体原理图,来个大致印象
那个G,就是生成器,那个D,就是判别器。其余就是常规表示网络的结构了,是如何设计的。各位应该发现图中还有个小y,这就是cGAN网络中的c
较常规GAN网络,多了个条件标签
这里想啰嗦一句,这个版本的cGAN在条件标签的处理上,用的是concatenate操作,也就是在某个维度上,直接叠加相关数据,一会在代码中也有显现。其余的还可不可以用别的操作来改善效果,本人很菜,还没有试过。
如图所示,因为用的是MNIST(手写数字体)数据集,每张图片的shape是[1, 28, 28],也就是单通道,分辨率是28x28。又因为是使用神经网络提取特征,所以需要将图片打平操作,所以生成器(G)最后生成的本来应该是一张图片的shape,这里的话就是784,这个数字各位应该不陌生,不多废话。
可以看到,G的输入就是100维的一个随机数,shape是[100, 100],这里生成100张假的数字体图像,对应的label,小y,也就是[100, 10],做了One-Hot编码。然而输出就是[100, 784],经过一些类似imshow等显示图片的函数的时候,在reshape成[100, 1, 28, 28], 就可以显示啦
再看D,判别器,这个就相对简单一些,就是平常看到的分类网络的结构。输入是由G生成的假的图像数据,输出只有两个,真or假,real or fake,用1和0代替结果,shape为[batch, 1],只有一堆0或1作为label。在代码中一看便知,就理解了。
ok,少废话,上代码(下面是完整的,来源也是GitHub的那位老哥的仓库,稍做了些修改,要不在我的环境下直接跑不了)
import os, time
import matplotlib.pyplot as plt
import itertools
import pickle
import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# G(z)
class generator(nn.Module):
# initializers
def __init__(self):
super(generator, self).__init__()
self.fc1_1 = nn.Linear(100, 256)
self.fc1_1_bn = nn.BatchNorm1d(256)
self.fc1_2 = nn.Linear(10, 256)
self.fc1_2_bn = nn.BatchNorm1d(256)
self.fc2 = nn.Linear(512, 512)
self.fc2_bn = nn.BatchNorm1d(512)
self.fc3 = nn.Linear(512, 1024)
self.fc3_bn = nn.BatchNorm1d(1024)
self.fc4 = nn.Linear(1024, 784)
# weight_init
def weight_init(self, mean, std):
for m in self._modules:
normal_init(self._modules[m], mean, std)
# forward method
def forward(self, input, label):
x = F.relu(self.fc1_1_bn(self.fc1_1(input)))
y = F.relu(self.fc1_2_bn(self.fc1_2(label)))
x = torch.cat([x, y], 1)
x = F.relu(self.fc2_bn(self.fc2(x)))
x = F.relu(self.fc3_bn(self.fc3(x)))
x = F.tanh(self.fc4(x))
return x
class discriminator(nn.Module):
# initializers
def __init__(self):
super(discriminator, self).__init__()
self.fc1_1 = nn.Linear(784, 1024)
self.fc1_2 = nn.Linear(10, 1024)
self.fc2 = nn.Linear(2048, 512)
self.fc2_bn = nn.BatchNorm1d(512)
self.fc3 = nn.Linear(512, 256)
self.fc3_bn = nn.BatchNorm1d(256)
self.fc4 = nn.Linear(256, 1)
# weight_init
def weight_init(self, mean,