简介
上文说到生成对抗网络GAN
能够通过训练学习到数据分布,进而生成新的样本。可是GAN
的缺点是生成的图像是随机的,不能控制生成图像属于何种类别。比如数据集包含飞机、汽车和房屋等类别,原始GAN
并不能在测试阶段控制输出属于哪一类。
为此,研究人员提出了Conditional Generative Adversarial Network
(简称CGAN
),CGAN
的图像生成过程是可控的。
本文包含以下3个方面:
(1)CGAN原理分析
(2)pytorch实现CGAN
(3)视觉结果和损失函数曲线
CGAN
的思想是非常简单的,这也验证了那句话,越简单的想法越伟大!
1、CGAN原理分析
1.1 网络结构
CGAN
是在GAN基础上做的一种改进,通过给原始GAN
的生成器Generator
(下文简记为G
)和判别器Discriminator
(下文简记为D
)添加额外的条件信息,实现条件生成模型。CGAN
原文中作者说额外的条件信息可以是类别标签或者其它的辅助信息,本文使用条件信息(记为y
)作为例子。
CGAN
的核心操作是将条件信息加入到G
和D
中,下面分别进行讨论:
(1)原始GAN
生成器输入是噪声信号,类别标签可以和噪声信号组合作为隐空间表示;
(2)原始GAN
判别器输入是图像数据(真实图像和生成图像),同样需要将类别标签和图像数据进行拼接作为判别器输入。
从上图(来自CGAN论文)中可以看出,CGAN的网络相对于原始GAN网络并没有变化,改变的仅仅是生成器G和判别器D的输入数据,这就使得CGAN可以作为一种通用策略嵌入到其它的GAN网络中。
2.2 损失函数
原始GAN包含一个生成器和一个判别器,其中生成器G和判别器D进行极大极小博弈,损失函数如下:
CGAN添加的额外信息y只需要和x与z进行合并,作为G和D的输入即可,由此得到了CGAN的损失函数如下:
1.3 训练策略与实验结果
CGAN
在mnist
数据集上进行了实验,对于生成器:使用数字的类别y
作为标签,并进行了one-hot
编码,噪声z
来自均均匀分布;噪声z映射到200
维的隐层,类别标签映射到1000
维的隐层,然后进行拼接作为下一层的输入,激活函数使用ReLU
;最后一层使用Sigmoid
函数,生成的样本为784
维(使用的mnist
长宽为28x28=784
)。得到的实验结果如下:
上图中每行是由相同的标签生成的,说明CGAN
的确可以通过给生成器特定的标签,实现特定模式(类别)的生成。CGAN
还做了其它的实验,都证明了CGAN
的模式控制能力。
2、pytorch实现
2.1 生成器实现
CGAN
的生成器输入为噪声z
和类别标签y
的联合输入,所以这里我直接在对DCGAN
的生成器进行改动(DCGAN
的代码和分析参见我之前的文章):
class Generator(nn.Module):
def __init__(self, z_dim, num_classes):
super().__init__()
self.z_dim = z_dim
self.num_classes = num_classes
net = []
# 1:设定每次反卷积的输入和输出通道数
# 卷积核尺寸固定为3,反卷积输出为“SAME”模式
channels_in = [self.z_dim+self.num_classes, 512, 256, 128, 64]
channels_out = [512, 256, 128, 64, 3]
active = ["R", "R", "R", "R", "tanh"]
stride = [1, 2, 2, 2, 2]
padding = [0, 1, 1, 1, 1]
for i in range(len(channels_in)):
net.append(nn.ConvTranspose2d(in_channels=channels_in[i], out_channels=channels_out[i],
kernel_size=4, stride=stride[i], padding=padding[i], bias=False))
if active[i] == "R":
net.append(nn.BatchNorm2d(num_features=channels_out[i]))
net.append(nn.ReLU())
elif active[i] == "tanh":
net.append(nn.Tanh())
self.generator = nn.Sequential(*net)
def forward(self, x, label):
x = x.unsqueeze(2).unsqueeze(3)
label = label.unsqueeze(2).unsqueeze(3)
data = torch.cat(tensors=(x, label), dim=1)
out = self.generator(data)
return out
2.2 判别器的实现
CGAN
的判别器需要使用图像(生成的和真实的)和类别标签y联合输入,所以这里也是对DCGAN
的判别器第一层进行改动:
class Discriminator(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.num_classes = num_classes
net = []
# 1:预先定义
channels_in = [3+self.num_classes, 64, 128, 256, 512]
channels_out = [64, 128, 256, 512, 1]
padding = [1, 1, 1, 1, 0]
active = ["LR", "LR", "LR", "LR", "sigmoid"]
for i in range(len(channels_in)):
net.append(nn.Conv2d(in_channels=channels_in[i], out_channels=channels_out[i],
kernel_size=4, stride=2, padding=padding[i], bias=False))
if i == 0:
net.append(nn.LeakyReLU(0.2))
elif active[i] == "LR":
net.append(nn.BatchNorm2d(num_features=channels_out[i]))
net.append(nn.LeakyReLU(0.2))
elif active[i] == "sigmoid":
net.append(nn.Sigmoid())
self.discriminator = nn.Sequential(*net)
def forward(self, x, label):
label = label.unsqueeze(2).unsqueeze(3)
label = label.repeat(1, 1, x.size(2), x.size(3))
data = torch.cat(tensors=(x, label), dim=1)
out = self.discriminator(data)
out = out.view(data.size(0), -1)
return out
3、视觉结果和损失函数曲线
自己的数据包含3
类:动漫脸、人脸、鞋。其实当时还选择了其它数据,但是最后发现,在数据集质量不够高时,生成的样本明显不够好,最后筛选才确定了使用这三个数据集。当然,自己的实验结果也非常差!迭代的总体次数为6000
次左右,生成了下面的样本:
上面这个动漫脸完全看不清,人脸中也看不见嘴,下面这个结果更好些:
实际上,结果比较差的主要原因还是在于生成器的结构(不够深,拟合能力不够强),如果换成是近两年的生成器结构,生成的效果肯定会好很多。当然,调参数而是很重要的一个方面,自己也没有进行细致的调参。下面这张图显示了迭代过程中生成的图像的变化:
损失函数没有展示出收敛的趋势,尤其是生成器的损失似乎还在增加: