条件生成对抗网络-CGAN原理分析与pytorch实现

简介

上文说到生成对抗网络GAN能够通过训练学习到数据分布,进而生成新的样本。可是GAN的缺点是生成的图像是随机的,不能控制生成图像属于何种类别。比如数据集包含飞机、汽车和房屋等类别,原始GAN并不能在测试阶段控制输出属于哪一类。

为此,研究人员提出了Conditional Generative Adversarial Network(简称CGAN),CGAN的图像生成过程是可控的。

本文包含以下3个方面:

(1)CGAN原理分析
(2)pytorch实现CGAN
(3)视觉结果和损失函数曲线

CGAN的思想是非常简单的,这也验证了那句话,越简单的想法越伟大!

1、CGAN原理分析

1.1 网络结构

CGAN是在GAN基础上做的一种改进,通过给原始GAN的生成器Generator(下文简记为G)和判别器Discriminator(下文简记为D)添加额外的条件信息,实现条件生成模型。CGAN原文中作者说额外的条件信息可以是类别标签或者其它的辅助信息,本文使用条件信息(记为y)作为例子。

CGAN的核心操作是将条件信息加入到GD中,下面分别进行讨论:

(1)原始GAN生成器输入是噪声信号,类别标签可以和噪声信号组合作为隐空间表示;
(2)原始GAN判别器输入是图像数据(真实图像和生成图像),同样需要将类别标签和图像数据进行拼接作为判别器输入。

在这里插入图片描述
从上图(来自CGAN论文)中可以看出,CGAN的网络相对于原始GAN网络并没有变化,改变的仅仅是生成器G和判别器D的输入数据,这就使得CGAN可以作为一种通用策略嵌入到其它的GAN网络中。

2.2 损失函数

原始GAN包含一个生成器和一个判别器,其中生成器G和判别器D进行极大极小博弈,损失函数如下:
在这里插入图片描述
CGAN添加的额外信息y只需要和x与z进行合并,作为G和D的输入即可,由此得到了CGAN的损失函数如下:
在这里插入图片描述

1.3 训练策略与实验结果

CGANmnist数据集上进行了实验,对于生成器:使用数字的类别y作为标签,并进行了one-hot编码,噪声z来自均均匀分布;噪声z映射到200维的隐层,类别标签映射到1000维的隐层,然后进行拼接作为下一层的输入,激活函数使用ReLU;最后一层使用Sigmoid函数,生成的样本为784维(使用的mnist长宽为28x28=784)。得到的实验结果如下:
在这里插入图片描述
上图中每行是由相同的标签生成的,说明CGAN的确可以通过给生成器特定的标签,实现特定模式(类别)的生成。CGAN还做了其它的实验,都证明了CGAN的模式控制能力。

2、pytorch实现

2.1 生成器实现

CGAN的生成器输入为噪声z和类别标签y的联合输入,所以这里我直接在对DCGAN的生成器进行改动(DCGAN的代码和分析参见我之前的文章):

class Generator(nn.Module):
    def __init__(self, z_dim, num_classes):
        super().__init__()
        self.z_dim = z_dim
        self.num_classes = num_classes
        net = []
        # 1:设定每次反卷积的输入和输出通道数
        #   卷积核尺寸固定为3,反卷积输出为“SAME”模式
        channels_in = [self.z_dim+self.num_classes, 512, 256, 128, 64]
        channels_out = [512, 256, 128, 64, 3]
        active = ["R", "R", "R", "R", "tanh"]
        stride = [1, 2, 2, 2, 2]
        padding = [0, 1, 1, 1, 1]
        for i in range(len(channels_in)):
            net.append(nn.ConvTranspose2d(in_channels=channels_in[i], out_channels=channels_out[i],
                                          kernel_size=4, stride=stride[i], padding=padding[i], bias=False))
            if active[i] == "R":
                net.append(nn.BatchNorm2d(num_features=channels_out[i]))
                net.append(nn.ReLU())
            elif active[i] == "tanh":
                net.append(nn.Tanh())

        self.generator = nn.Sequential(*net)

    def forward(self, x, label):
        x = x.unsqueeze(2).unsqueeze(3)
        label = label.unsqueeze(2).unsqueeze(3)
        data = torch.cat(tensors=(x, label), dim=1)
        out = self.generator(data)
        return out
2.2 判别器的实现

CGAN的判别器需要使用图像(生成的和真实的)和类别标签y联合输入,所以这里也是对DCGAN的判别器第一层进行改动:

class Discriminator(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        net = []
        # 1:预先定义
        channels_in = [3+self.num_classes, 64, 128, 256, 512]
        channels_out = [64, 128, 256, 512, 1]
        padding = [1, 1, 1, 1, 0]
        active = ["LR", "LR", "LR", "LR", "sigmoid"]
        for i in range(len(channels_in)):
            net.append(nn.Conv2d(in_channels=channels_in[i], out_channels=channels_out[i],
                                 kernel_size=4, stride=2, padding=padding[i], bias=False))
            if i == 0:
                net.append(nn.LeakyReLU(0.2))
            elif active[i] == "LR":
                net.append(nn.BatchNorm2d(num_features=channels_out[i]))
                net.append(nn.LeakyReLU(0.2))
            elif active[i] == "sigmoid":
                net.append(nn.Sigmoid())
        self.discriminator = nn.Sequential(*net)

    def forward(self, x, label):
        label = label.unsqueeze(2).unsqueeze(3)
        label = label.repeat(1, 1, x.size(2), x.size(3))
        data = torch.cat(tensors=(x, label), dim=1)
        out = self.discriminator(data)
        out = out.view(data.size(0), -1)
        return out

3、视觉结果和损失函数曲线

自己的数据包含3类:动漫脸、人脸、鞋。其实当时还选择了其它数据,但是最后发现,在数据集质量不够高时,生成的样本明显不够好,最后筛选才确定了使用这三个数据集。当然,自己的实验结果也非常差!迭代的总体次数为6000次左右,生成了下面的样本:
在这里插入图片描述
上面这个动漫脸完全看不清,人脸中也看不见嘴,下面这个结果更好些:
在这里插入图片描述
实际上,结果比较差的主要原因还是在于生成器的结构(不够深,拟合能力不够强),如果换成是近两年的生成器结构,生成的效果肯定会好很多。当然,调参数而是很重要的一个方面,自己也没有进行细致的调参。下面这张图显示了迭代过程中生成的图像的变化:
在这里插入图片描述

损失函数没有展示出收敛的趋势,尤其是生成器的损失似乎还在增加:
在这里插入图片描述

  • 6
    点赞
  • 117
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
CGAN是一种生成对抗网络,它可以基于给定的条件生成合成数据。以下是基于PyTorch的CGAN代码框架的一个示例: ```python import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms # 定义生成器 class Generator(nn.Module): def __init__(self, input_size, output_size): super(Generator, self).__init__() self.fc1 = nn.Linear(input_size, 128) self.fc2 = nn.Linear(128, 256) self.fc3 = nn.Linear(256, 512) self.fc4 = nn.Linear(512, output_size) self.relu = nn.ReLU() self.tanh = nn.Tanh() def forward(self, x): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.relu(self.fc3(x)) x = self.tanh(self.fc4(x)) return x # 定义判别器 class Discriminator(nn.Module): def __init__(self, input_size, output_size): super(Discriminator, self).__init__() self.fc1 = nn.Linear(input_size, 512) self.fc2 = nn.Linear(512, 256) self.fc3 = nn.Linear(256, output_size) self.relu = nn.ReLU() self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.sigmoid(self.fc3(x)) return x # 定义CGAN模型 class CGAN(nn.Module): def __init__(self, generator, discriminator): super(CGAN, self).__init__() self.generator = generator self.discriminator = discriminator def forward(self, z, c): x_fake = self.generator(torch.cat([z, c], dim=1)) x_real = torch.cat([x_fake, c], dim=1) y_fake = self.discriminator(x_fake) y_real = self.discriminator(x_real) return y_fake, y_real # 定义训练函数 def train_cgan(generator, discriminator, cgan, data_loader, num_epochs, device): generator.to(device) discriminator.to(device) cgan.to(device) criterion = nn.BCELoss() optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999)) optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999)) for epoch in range(num_epochs): for i, (x_real, c) in enumerate(data_loader): x_real = x_real.to(device) c = c.to(device) # 训练判别器 optimizer_d.zero_grad() z = torch.randn(x_real.size(0), 100).to(device) y_fake, y_real = cgan(z, c) loss_d = criterion(y_real, torch.ones_like(y_real)) + criterion(y_fake, torch.zeros_like(y_fake)) loss_d.backward(retain_graph=True) optimizer_d.step() # 训练生成器 optimizer_g.zero_grad() z = torch.randn(x_real.size(0), 100).to(device) y_fake, _ = cgan(z, c) loss

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值