第2章 GAN初步:2.3 生成手写数字2

2.3.7 模式崩溃

我们刚刚看到的现象,在GAN训练中非常常见,我们称它为模式崩溃(mode collapse)。或者,模式崩塌、模式坍塌。在这里插入图片描述
在MNIST的案例中,我们希望生成器能够创建代表所有10个数字的图像。当模式崩溃发生时,生成器只能生成10个数字中的一个或部分数字,无法达到我们的要求。

发生模式崩溃的原因尚未被完全理解。许多相关的研究正在进行中,我们选取其中一些相对比较成熟的理论进行讨论。

其中一种解释是,在鉴别器学会向生成器提供良好的反馈之前,生成器率先发现一个一直被判定为真实图像的输出。为此,有人提出一些解决方案,比如更频繁地训练鉴别器。但在实践中,这样做往往效果不佳。这就表明,解决问题的关键不仅在于训练的数量,也在于训练的质量。

在我们的例子中,生成器的损失值不断增加(见2.3.6节),表明它的学习没有进展。可能的原因是,鉴别器没有很好地为它提供有效的反馈。这再次表明,训练质量是一个挑战。接下来,我们将试验一些想法,以提高鉴别器对生成器反馈的质量。

2.3.8 改良GAN的训练

在开始改良之前,先备份之前生成手写数字图像的笔记本。
现在,我们试图通过提高GAN的训练质量,解决模式崩溃和图像清晰度低的问题。有的方法我们已经在第1章改良MNIST分类器时用过。

第一个改良是,使用二元交叉熵BCELoss()代替损失函数中的均方误差MSELoss()。我们在1.3.1节讨论过,在神经网络执行分类任务时,二元交叉熵更适用。相比于均方误差,它更大程度地奖励正确的分类结果,同时惩罚错误的结果。在这里插入图片描述
我们可以做的下一个改良是,在鉴别器和生成器中使用LeakyReLU()激活函数。因为我们所预期的输出值范围为0~1,所以我们只会在中间层后使用LeakyReLU(),最后一层仍保留S型激活函数。我们在1.3.2节已经讨论过LeakyReLU()如何解决梯度消失问题。一般来说,这是一种常用的提高神经网络训练质量的方法。

另一种改良是,将神经网络中的信号进行标准化,以确保它们的均值为0。同时,标准化也可以有效地限制信号的方差,避免较大值引起的网络饱和。在1.3.4节中,我们已经看到LayerNorm()如何对训练产生积极的影响。

下面是一个改良后的鉴别器神经网络的代码。
在这里插入图片描述
生成器的代码也进行相同的改良。在这里插入图片描述
还有一种我们之前尝试过的改良是使用Adam优化器(见1.3.3 节)。我们把它同时用于鉴别器和生成器。在这里插入图片描述
让我们看一下采用以上4个改良方案的效果。在这里插入图片描述
遗憾的是,模式崩溃仍然存在。图像的清晰度有所提高,结构更清晰了,但仍然不是一个清楚的数字。

让我们更深入地思考一下如何进一步改良GAN。.

生成过程的起始点是一个种子值。起初,我们用常数值0.5。随后,我们把它改为一个随机值,因为我们知道,对于固定的输入,任何神经网络总会输出相同的结果。也许生成器神经网络觉得,把一个单值转换成784像素来代表一个数字实在太难了

我们可以通过提供更多的输入种子来降低这种难度。比如,我们可以尝试100个输入节点,每个节点都是一个随机值。让我们在代码中更新生成器的神经网络定义。在这里插入图片描述
再看一下效果。在这里插入图片描述
现在图像更清晰了,看起来也更像手写数字了,具体地说有点像0。遗憾的是,所有生成的图像都是相同的,说明我们还没解决模式崩溃问题。

不要灰心丧气——即便是最顶尖的GAN研究者,也同样面临模式崩溃的问题。

如果我们继续思考,不难想到输入生成器的随机种子和输入鉴别器的种子,不应该是一样的。

  • 输入鉴别器的随机图像的像素值,需要在0~1的范围内均匀抽取(uniformly chosen)。这个范围对应真实数据集中图像像素的范围。因为目前的测试是将鉴别器的性能与随机判断进行对比,所以这些值应该是均匀抽取的,而不是从有偏差的正态分布中抽取。
  • 输入生成器的随机值不需要符合0~1的范围。我们知道,标准化一个网络中的信号有助于训练。标准化后的信号会集中在0附近,且方差有限。我们在《Python神经网络编程》中初始化网络链接权重时具体讨论过。这时,从一个平均值为0、方差为1的正态分布中抽取种子更加合理。

现在,让我们分别创建两个生成随机数据的函数。它们看起来很相似,不过一个使用torch.rand(),而另一个使用torch.randn().

  • torch.normal():返回从均值means和标准差std的离散正态分布中抽取随机张量
  • torch.randn():生成满足标准正态分布(0~1)的随机张量
  • torch.rand():返回从区间[0, 1)的均匀分布中抽取的一组随机数
    在这里插入图片描述
    在输入鉴别器时,我们会使用generate_random_image(784);在输入生成器时,我们使用generate_random_seed(100)。

下面是改良后的GAN训练循环。在这里插入图片描述
我们看看效果如何。在这里插入图片描述
太赞了!看上去我们已经解决了模式崩溃问题。现在,生成器可以生成不同的数字。图中的形状看起来一个像8,一个像2,还有一个像3。也有的比较模糊,其中一个看起来既像4又像9。

让我们回顾一下到目前为止的进展。我们训练了一个生成器,并能用它画出手写数字图像。即便没有直接看到任何真实的图像,生成的图像也几乎与训练数据看起来没有区别。这真的很酷。更酷的是,只需改变随机种子,训练过的生成器就可以生成多种不同的数字。

这是一个了不起的成绩。有时候,要解决模式崩溃可能非常困难。很多时候,甚至根本找不到有效的解决方案。

让我们观察一下损失图,看看它们是否能提供一些信息。因为现在使用了BCELoss(),所以这些值并不保证在0~1的范围内。我们需要更新鉴别器和生成器的plot_progress() 函数,删除损失值范围的上限,同时添加更多的水平网格线。在这里插入图片描述
下图所示为鉴别器的训练损失值。在这里插入图片描述
由上图可见,损失值迅速下降到接近于0,并一直保持在很低的位置。训练期间,损失值偶尔发生跳跃。这说明生成器和鉴别器之间仍然没有取得平衡。
下图中是生成器的训练损失值。在这里插入图片描述
损失值先是上升,表示在训练早期生成器落后于鉴别器。之后,损失值下降并保持在3左右。记住,与MSELoss不同,BCELoss没有1.0的上限。

这些损失图看起来有些令人失望,因为损失值的范围更广了。不过,它们仍然好于改良之前的损失图。在之前的图中,鉴别器的损失值在下降时没有太大的波动,生成器的损失值在上升时同样非常工整。这些现象看似令人满意,但不断增加的生成器损失值并不是我们希望的。理想的情况应该是,生成器的损失值只在一个有限的平均值附近变化。

一个很好的问题是,如果我们达到了平衡,BCELoss应该是什么?如果我们运行简单的1010 GAN并达到平衡,由于使用BCELoss,我们会看到生成器和鉴别器的损失值都接近于0.69。读者可以自己试试。对一个完全不确定的分类器使用二元交叉熵,根据数学定义可以计算出,理想的损失值为ln 2或0.693。更多内容可以在附录A中找到。

我们成功地解决了模式崩溃的问题,不过,图像质量还有待改良。我们来看看通过增加训练周期 (epoch)来训练更长时间是否有帮助。我们可以很方便地将GAN训练循环与周期外部循环结合起
来。

以下的图像是训练4个周期后,也就是使用所有训练数据4次的生成效果。总共耗时大约30分钟。在这里插入图片描述
图像看起来好多了。如果读者有时间,可以试试训练8个周期,应该需要1小时左右。

事实上,还有更多改良方法有待我们继续探索。但是,由于我们已经解决了模式崩溃的问题,也可以从生成器获得高质量的图像,因此这里就先告一段落了。

读者可能会问,可以解决模式崩溃是不是因为我们在生成器种子中使用了randn()。如果我们还原之前的代码,在GAN架构中只使用最基本的设置,即便我们将种子改为使用randn(),模式崩溃问题依然不会得到解决。解决问题的是多数或者全部改良的组合作用。 例如,仅为大小为100的生成器种子使用randn()并不能解决模式崩溃问题。读者可以自己试试。

读者可能还希望知道,为什么我们满足于尚未像简单的1010GAN那样达到平衡的生成器和鉴别器,本节的损失图显示,鉴别器的损失值迅速下降到接近于0,并保持在低位,而生成器的损失值仍然很高。
在许多真实的GAN场景中,即使没有达到平衡,仍然可以得到一个可以生成高质量图像的生成器。
我们的最终目标是生成看起来逼真的图像。如果能改善这种平衡,我们当然也应该尝试。我们将继续绘制损失图,因为损失图可以帮我们了解训练的实际状况。例如,MNIST损失图告诉我们训练并不是混乱且不稳定的。

2.3.9 种子实验

到目前为止,我们把GAN的种子当作一个随机数。经过训练后,种子获得了一些有趣的特性。让我们一起来看一下。
假设有种子1(seed1)和种子2(seed2)两个不同的种子。我们可以用它们分别生成图像。
现在,假设seed1和seed2之间有一个中间种子,使用这个种子会生成什么样的图像呢?除此之外,使用在seed1和seed2之间不同位置上的种子又会生成什么样的图像呢?在这里插入图片描述
让我们试一试。首先,我们需要一个以MNIST数据集训练的GAN。我们可以继续使用之前的笔记本。

以下代码将一个随机种子赋值给seed1,以备后用。接着,我们画出生成的图像。在这里插入图片描述
接着,使用seed2重复上述步骤。在这里插入图片描述
并不是每次生成的图像都很清晰。我们需要重复运行上面的代码,直到得到一个较清楚的数字。

以我自己的实验为例,下图是seed1生成的图像,看起来像5。在这里插入图片描述
下图是seed2生成的图像,看起来像3。在这里插入图片描述
接着,让我们通过代码计算seed1与seed2之间距离相等的10个种子。在这里插入图片描述
上面的代码看起来可能比较复杂,不过它做的只是在seed1和seed2之间选择10个点,并以它们为种子生成图像。 下图展示了包含seed1和seed2在内的12个种子生成的图像。在这里插入图片描述
我们可以明显地看出,随着种子从seed1到seed2,图像从5平滑地演变成了3。
让我们再做另外一个实验。如果把种子相加,又会生成什么图像呢?在这里插入图片描述
这段代码很容易理解。一个新的seed3由seed1与seed2相加得到,并输入生成器。在这里插入图片描述
结果图像看起来非常像8。这是合乎情理的,因为我们把5和3重叠,应该也差不多是这个样子。这再次表明了种子一个很好的特性,即种子相加也会造成它们生成的图像的叠加。

我们看到了种子相加的效果。让我们再看看把种子相减会发生什么。
在这里插入图片描述
seed1和seed2的差被输入生成器。在这里插入图片描述
结果图像看起来既像5又像6。它看起来并不完全合乎逻辑,至少不像从5的笔画中减去3的笔画。或许种子的特性并没有这么简单。

让我们再试验一个例子。下图中罗列的图像分别由起始种子(seed1和seed2)、插入 (interpolated)种子、总和种子(seed1+seed2),以及差值种子(seed1−seed2)生成。在这里插入图片描述
我们看到,两个起始种子都生成了看起来像9的图像。在它们之间插入的种子也生成了类似9的图像。两个种子的总和种子生成的图像也是9,这也并不令人意外。令人惊讶的是,差值种子生成的图像却成了8。好奇怪呀!

以下是另一个例子,两个种子生成了非常相似的图像,看起来像5,差值种子却生成了一个非常不同的、看起来很像3的图像。在这里插入图片描述

2.3.10 学习重点

  • 处理单色图像不需要改变神经网络的设计。将二维像素数组简单地展开或重构成一维列表,即可输入鉴别器的输入层。如何做到这一点并不重要,不过要注意保持一致性。
  • 模式崩溃是指一个生成器在有多个可能输出类别的情况下,一直生成单一类别的输出。 模式崩溃是GAN训练中最常见的挑战之一,其原因和解决方法尚未被完全理解,因此是一个相当活跃的研究课题。
  • 着手设计GAN的一个很好开端是,镜像反映生成器和鉴别器的网络架构。这样做的目的是,尽量使它们之间达到平衡。在训练中,其中一方不会领先另一方太多。
  • 实验证据表明,成功训练GAN的关键是质量,而不仅仅是数量。
  • 生成器种子之间的平滑插值会生成平滑的插值图像。将种子相加似乎与图像特征的加法组合相对应。不过,种子相减所生成的图像并不遵循任何直观的规律
  • 理论上,一个经过完美训练的GAN的最优MSE损失(均方误差损失)为0.25,最优BCE损失(二元交叉熵损失)为ln 2或0.693。
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值