第3章 卷积GAN和条件式GAN:3.2 条件式GAN

3.2 条件式GAN

之前,我们构建的MNIST GAN可以生成各种不同的输出图像。我们很好地避免了单一化和模式崩溃,它们是设计GAN时的主要挑战。

如果能通过某种方式引导GAN生成多样化的图像,同时又仅限于生成训练数据中的一类图像,那将是非常有价值的。例如,我们可以要求GAN生成不同的、但都代表数字3的图像。

又如,我们用人脸图像进行训练,如果情绪是训练数据中的一个类别,那么我们可以要求GAN只生成具有快乐表情的人脸图像。

3.2.1 条件式GAN架构

让我们构想一下这种架构会是什么样子的。

如果希望让一个训练后的GAN生成器输出一个指定类型的图像,则需要告诉它我们希望的输出类型。这意味着,我们需要将类型作为生成器输入的一部分,如同随机种子一样。

鉴别器的情况更加复杂。以前,它唯一的工作是尝试将真实的图像和生成的图像分开。现在,我们希望它同时学习将类型标签与图像关联起来。 不然,它就无法向生成器提供反馈,生成器也就无法将图像与标签关联起来。这意味着,我们还需要将类型标签与图像一起输入鉴别器。

下图显示的架构是条件式(conditional)GAN。在这里插入图片描述
主要的改变在于,现在生成器和鉴别器的输入都在图像数据的基础上加入了类型标签。

3.2.2 鉴别器

让我们修改之前的全连接MNIST GAN,实现这个架构。

首先,我们需要更新鉴别器,使它可以同时接收输入图像的像素数据和标签信息。一种简单的方法是扩展forward()函数,使它可以同时接收图像张量和标签张量为输入变量,再直接将它们拼接起来。
标签张量就是我们之前在Dataset类中创建的独热张量。

之前的:

    def forward(self, inputs):
        # simply run model
        return self.model(inputs)

修改后:
在这里插入图片描述
通过torch.cat()函数可以方便地将两个张量拼接起来。从Dataset类中返回的图像张量长度为784,标签张量的长度为10,所以拼接起来后的长度为794。

由于我们扩展了输入的大小,因此需要更改第一层神经网络的定义,将预期输入的大小改为784+10。在这里插入图片描述
我们将更新后的输入张量长度写成784+10,而不是794。这是为了方便别人在阅读这段代码时看到这个变化,也能明白这个变化的原因——这是一个很好的编程习惯

对鉴别器的最后一个改动是,在train()函数里需要将标签添加到调用forward()的输入参数中。下面只显示了train()函数的前几行。

在这里插入图片描述
让我们用常规的方法来测试鉴别器。为此,我们需要更新训练循环代码,将额外的标签张量输入train()函数。在这里插入图片描述
我们还需要为随机生成的图像搭配一个随机类别标签。为此,我们创建了一个便利函数generate_random_one_hot(),用来生成一个随机的独热标签向量。在这里插入图片描述
让我们通过损失值来看看鉴别器的效果。在这里插入图片描述
相比原来的鉴别器,修改后的鉴别器的训练损失值并没有太大变化。

3.2.3 生成器

现在,让我们来想象一下生成器。由于要把种子和标签张量输入生成器,因此需要修改forward()函数。我们需要把输入参数拼接起来,再输入神经网络。在这里插入图片描述
网络的第一层需要修改,以便接收10个额外的输入值。在这里插入图片描述
最后,train()函数也需要接收标签输入。在这里插入图片描述
在向生成器输入本身的forward()函数,以及将生成的图像传递给鉴别器的forward()函数时,我们使用的标签张量是相同的。否则,鉴别器无法向生成器提供相关标签的反馈。

3.2.4 训练循环

GAN的主训练循环同样需要修改,输入一个标签张量给鉴别器和生成器。以下代码只显示了周期循环内的内容。在这里插入图片描述
值得注意的是,我们创建了一个变量random_label。这样一来,在用生成的图像训练鉴别器时,我们我可以对生成器和鉴别器输入同一个标签张量。

3.2.5 绘制图像

当生成器训练完成后,我们可以测试用它为指定的几个标签生成图像。我们先在生成器类中添加一个新的plot_images() 方法。在这里插入图片描述
该函数将接收一个整数类型的标签,并将它转换成独热张量,再输入生成器。6个不同的随机种子生成了6幅图像,并绘制在网格中。

3.2.6 条件式GAN的结果

我花了大约1小时30分钟将GAN训练了12个周期。乍看之下,鉴别器损失值与更新之前的损失值没有太大区别。但是如果仔细看会发现新的损失值并不接近0,看起来甚至在增加。这是一个好的现象,因为GAN的理想损失值并不是0。
在这里插入图片描述

生成器损失值看起来同样与改动前的GAN差不多。如果仔细看,会发现均值也不是0,这很好。在这里插入图片描述
这表明,输入额外的标签信息有助于训练GAN。这是合理的,因为鉴别器有了更多有价值的信息,帮助它判断图像是否真实,并反馈给生成器。

最后,我们使用plot_images(9)让GAN生成若干幅9的图像。在这里插入图片描述
成功了!我们的条件式GAN的确生成了几幅数字9的图像。更好的是,这些图像都是不一样的。

下图分别显示了6幅随机生成的数字9、3、1和5的图像。在这里插入图片描述
我们看到,由GAN生成的图像都是我们指定的数字,而且都是不一样的。

能生成指定类型的多样化图像真的很强大。我们可以想象到很多应用,比如生成具有特定情绪表情的人像、具有指定颜色的花朵等。实现这一功能的关键在于,训练数据需要用我们希望生成的类别进行标记。

3.2.7 学习要点

  • 不同于GAN,条件式GAN可以直接生成特定类型的输出。
  • 训练条件式GAN,需要将类别标签分别与图像和种子一起输入鉴别器和生成器。
  • 由条件式GAN生成图像的质量,通常优于由不使用标签信息的同等GAN生成的图像。
  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值