(2020,ADA)用有限的数据训练生成对抗网络

Training generative adversarial networks with limited data

公众号:EDPJ

目录

0. 摘要

1. 简介

2. GAN 中的过度拟合

2.1 随机鉴别器增强(stochastic discriminator augmentation)

2.2 设计不泄漏的增强

2.3 我们的增强流程

3. 自适应判别器增强

4. 评估

4.1 从零开始训练

4.2 迁移学习

4.3 小数据集

5. 结论

6. 更广泛的影响

参考

S. 总结

S.1 核心思想

S.2 方法

S.3 其他贡献

S.4 相关内容分析


0. 摘要

使用太少的数据训练生成对抗网络 (GAN) 通常会导致鉴别器过度拟合,从而导致训练发散。 我们提出了一种自适应鉴别器增强(adaptive discriminator augmentation,ADA)机制,可以显着稳定有限数据范围内的训练。 该方法不需要更改损失函数或网络架构,并且适用于从头开始训练以及在另一个数据集上微调现有 GAN。 我们在几个数据集上证明,现在仅使用几千张训练图像就可以获得良好的结果,通常将 StyleGAN2 结果与数量级更少的图像相匹配。 我们希望这能为 GAN 开辟新的应用领域。 我们还发现,广泛使用的 CIFAR-10 实际上是一个有限的数据基准,并将记录 FID 从 5.59 提高到 2.42。

1. 简介

小数据集的关键问题是判别器过度拟合训练样例; 它对生成器的反馈变得毫无意义,训练开始发散。 在深度学习的几乎所有领域,数据集增强是防止过度拟合的标准解决方案。 例如,在旋转、噪声等条件下训练图像分类器会导致这些语义保持失真的不变性增加——这是分类器中非常理想的质量。 相比之下,在类似数据集增强下训练的 GAN 学习生成增强分布。 一般来说,这种对生成样本的增强“泄漏”是非常不可取的。 例如,噪声增强会导致噪声结果,即使数据集中没有噪声。

在本文中,我们演示了如何使用各种增强来防止鉴别器过度拟合,同时确保没有任何增强会泄漏到生成的图像中。 我们首先对防止增强泄漏的条件进行全面分析。 然后,我们设计了一组不同的增强和自适应控制方案,无论训练数据的数量、数据集的属性或确切的训练设置(例如,从头开始训练或迁移学习,都可以使用相同的方法)。

2. GAN 中的过度拟合

我们首先研究可用训练数据的数量如何影响 GAN 训练。

  • 图 1a 显示了我们对 FFHQ 不同子集的基线结果。 在每种情况下,训练都以相同的方式开始,但最终进展停止,FID 开始上升。 训练数据越少,这种情况发生得越早。
  • 图 1b、c 显示了训练期间真实图像和生成图像的鉴别器输出分布。 分布最初重叠,但随着鉴别器变得越来越有信心而逐渐分开,FID 开始恶化的点与分布之间充分重叠的损失一致。 这是过度拟合的强烈迹象,为单独的验证集测量的准确度下降进一步证明了这一点。 我们提出了一种方法来解决这个问题,方法是采用防止鉴别器变得过于自信的多功能增强。

2.1 随机鉴别器增强(stochastic discriminator augmentation)

根据定义,应用于训练数据集的任何增强都将继承到生成的图像。 Zhao 等人最近提出平衡一致性正则化 (balanced consistency regularization,bCR) 作为不把增强泄漏到生成图像的一种解决方案。 CR 指出,应用于同一输入图像的两组增强应该产生相同的输出。Zhao 等人为判别器损失添加 CR 项,并为真实图像和生成图像强制执行判别器一致性,而在训练生成器时不应用增强或一致性损失项(图 2a)。 因此,他们的方法通过使其对 CR 项中使用的增强视而不见来有效地努力推动鉴别器。 然而,实现这个目标为泄漏增强打开了大门,因为生成器可以自由地生成包含它们的图像而不会受到任何惩罚。 在第 4 节中,我们通过实验表明 bCR 确实存在这个问题,因此其效果与数据增强基本相似。 

我们的解决方案与 bCR 类似,因为我们还对喂给鉴别器的所有图像应用了增强。 然而,我们没有添加单独的 CR 损失项,而是仅使用增强图像评估鉴别器,并且在训练生成器时也这样做(图 2b)。 因此,我们称之为随机鉴别器增强的这种方法非常简单。 然而,这种可能性很少受到关注,可能是因为乍一看它是否有效并不明显:如果鉴别器从未看到训练图像的真实情况,则不清楚它是否能正确引导生成器(图 2c)。因此,我们将首先研究这种方法不会将增强泄漏到生成的图像的条件,然后从这种转换中构建一个完整的流程。

2.2 设计不泄漏的增强

鉴别器增强对应于在鉴别器上戴上扭曲的、甚至可能是破坏性的护目镜,并要求生成器生成通过护目镜观察时无法与训练集区分开来的样本。 Bora等人考虑在损坏的测量下训练 GAN 的类似问题,并表明训练隐式地消除损坏并找到正确的分布,只要损坏过程由数据空间上概率分布的可逆变换表示。 我们称这种增强操作为非泄漏(non-leaking)。

这些可逆变换的强大之处在于,它们允许通过仅观察增强集合得出有关基础集相等或不等的结论。 重要的是要理解这并不意味着对单个图像执行的增强需要是可撤销的。 例如,将输入图像设置为 0,90% 的时间这样极端的增强在概率分布意义上是可逆的:即使对于人类来说,通过忽略黑色图像直到只有 10% 保留的图像。 另一方面,从 {0°,90°,180°,270°} 中统一选择的随机旋转不可逆:无法辨别增强后方向之间的差异。

如果仅以 p < 1 的概率执行此旋转(相当于真实图像与其旋转图像按比例混合),情况就会发生变化:这会增加 0° 的相对出现次数,现在只有生成的图像具有正确的方向时,增强分布才能匹配。 类似地,许多其他随机增强可以设计为非泄漏,条件是它们以非零概率被跳过。 附录 C 表明,这可以适用于一大类广泛使用的增强,包括确定性映射(例如,基础变换)、加性噪声、变换群(例如,图像或颜色空间旋转、翻转和缩放)和投影 (例如,切割)。 此外,以固定顺序组合非泄漏增强会产生整体非泄漏增强。

在图 3 中,我们通过三个实际示例验证了我们的分析。

  • 具有对数正态分布的各向同性缩放是固有安全增强的示例,无论 p 的值如何都不会泄漏(图 3a)。
  • 但是,必须至少部分时间跳过上述旋转 90°(图 3b)。当 p 太高时,生成器无法知道生成的图像应该面向哪个方向,最终会随机选择一种可能性。正如所料,问题并不仅仅发生在 p = 1 的极限情况下。在实践中,由于有限采样、网络的有限表示能力、归纳偏差和训练动态。 当 p 保持在 0.85 以下时,生成的图像始终正确定向。在这些区域之间,生成器有时最初会选择错误的方向,然后部分地漂移到正确的分布。
  • 同样的观察结果适用于一系列连续的颜色增强(图 3c)。 这个实验表明,只要 p 保持在 0.8 以下,泄漏就不太可能在实践中发生。

2.3 我们的增强流程

鉴于 RandAugment 在图像分类任务中的成功,我们从假设一组最大程度不同的增强是有益的开始。 我们考虑一个包含 18 个转换的流程,这些转换分为 6 个类别:像素块传输(x 翻转、90 度旋转、整合平移)、更一般的几何转换、颜色转换、图像空间过滤、加性噪声和剪切。附录 B 中给出了各个增强的详细信息。请注意,我们也在训练生成器时执行增强(图 2b),这要求增强是可鉴别的。 我们通过使用深度学习框架提供的标准可鉴别原始配置来实现它们来实现这一点。

在训练期间,我们使用一组固定顺序的预定义转换来处理显示给鉴别器的每个图像。 增强的强度由标量 p ∈ [0,1] 控制,因此每个变换都以概率 p 应用或以概率 1-p 跳过。 我们总是对所有转换使用相同的 p 值。 随机化是针对每个增强和 minibatch.中的每个图像单独完成的。 鉴于流程中有许多增强,即使 p 值相当小,鉴别器也不太可能看到干净的图像(图 2c)。 尽管如此,只要 p 保持在实际安全门限以下,生成器就会被引导只生成干净的图像。

在图 4 中,我们通过对不同增强类别和不同大小的数据集遍历 p 来研究随机鉴别器增强的有效性。 我们观察到它在许多情况下可以显着改善结果。 然而,最佳增强强度在很大程度上取决于训练数据的数量,并且并非所有增强类别在实践中都同样有用。

  • 对于 2k 的训练集,绝大多数的好处来自像素块传输(blitting)和几何(geometric)变换。 颜色变换适度有益,而图像空间过滤、噪声和剪切(cutout)不是特别有用。 在这种情况下,最好的结果是使用强增强获得的。 曲线还表明,当 p→1 时,一些增强变得泄漏。
  • 对于 10k 的训练集,较高的 p 值帮助不大。
  • 而对于 140k 的情况明显不同:所有增强都是有害的。 基于这些结果,我们选择在其余测试中仅使用像素块传输、几何和颜色变换。
  • 图 4d 显示,虽然更强的增强减少了过度拟合,但它们也减慢了收敛速度。

在实践中,对数据集大小的敏感性要求进行昂贵的网格搜索,即便如此,依赖任何固定的 p 可能不是最佳选择。 接下来,我们通过使流程自适应来解决这些问题。

3. 自适应判别器增强

理想情况下,我们希望避免手动调整增强强度,而是根据过度拟合的程度动态控制它。 图 1 为此提出了一些可能的方法。 量化过度拟合的标准方法是使用单独的验证集并观察其相对于训练集的行为。 从图中我们看到,当过度拟合开始时,验证集开始表现得越来越像生成的图像。 这是一个可量化的效果,尽管在训练数据可能已经供不应求时需要单独的验证集。 我们还可以看到,随着 StyleGAN2 使用的非饱和损失,随着情况变得更糟,真实图像和生成图像的鉴别器输出在零附近对称地发散。 这种差异可以在没有单独的验证集的情况下进行量化。

让我们分别用 D_train、D_validation 和 D_generated 表示训练集、验证集和生成图像的鉴别器输出,以及它们在 N 个连续 minibatch 上的均值用 E[·] 表示。 在实践中,我们使用 N = 4,对应于 4×64 = 256 张图像。 我们现在可以将我们对图 1 的观察转化为两个可行的过拟合启发式方法:

对于这两种启发式算法,r = 0 表示没有过度拟合,r = 1 表示完全过度拟合,我们的目标是调整增强概率 p,使所选启发式算法匹配合适的目标值。 第一个启发式 r_v 表示相对于训练集和生成图像的验证集的输出。 由于它假定存在单独的验证集,因此我们主要将其作为一种比较方法。 第二个启发式 r_t 估计训练集中获得正判别器输出的部分。 我们发现,与直接查看 E[D_train] 相比,这对所选目标值和其他超参数的敏感度要低得多。

我们控制增强强度 p 如下。 我们将 p 初始化为零,并根据所选的过度拟合启发式每四个 minibatch 调整一次其值。 如果启发式方法表明过度拟合太多/太少,我们通过将 p 递增/递减固定量来进行反击。 我们设置调整大小,以便 p 可以足够快地从 0 上升到 1。 在每一步之后,我们将 p 从钳制到 0。我们将这种变体称为自适应鉴别器增强 (adaptive discriminator augmentation,ADA)。

在图 5a、b 中,我们测量目标值如何影响使用这些启发式方法可获得的质量。 我们观察到 r_v 和 r_t 都可以有效防止过度拟合,并且它们相比于比使用网格搜索找到的最佳固定 p 改进了结果。 我们选择在所有后续测试中使用更现实的 r_t 启发式,以 0.6 作为目标值。

图 5c 显示了随时间变化的 p。使用 2k 训练集,几乎总是在最后应用增强。 这超过了实际安全限制,之后一些增强会泄漏,表明增强不够强大。 事实上,在这种极端情况下,FID 在 p 约等于 0.5 后开始恶化。

图 5d 显示了具有自适应 vs 固定 p 的 rt 的演变,表明固定 p 往往在开始时太强而在结束时太弱。 

图 6 使用 ADA 重复图 1 中的设置。 无论训练集大小如何,现在都可以实现收敛,并且不再发生过度拟合。 如果没有增强,生成器从鉴别器接收到的梯度会随着时间的推移变得非常简单——鉴别器开始只关注少数几个特征,而生成器可以自由地创建其他无意义的图像。 使用 ADA,梯度场保持更详细,从而防止这种恶化。 在一个有趣的并行研究中,已经表明,通过使用类似的图像增强集合,可以使损失函数在回归设置中更加稳健。 

4. 评估

4.1 从零开始训练

图 7 显示了我们在 FFHQ 和 LSUN CAT 中跨训练集大小的结果,表明我们的自适应鉴别器增强 (ADA) 在有限的数据场景中显着提高了 FID。 我们还展示了平衡一致性正则化 (bCR) 的结果,之前尚未在有限数据的背景下对其进行研究。 我们发现,当数据缺乏不太严重时,bCR 可以非常有效,而且它的一组增强会泄漏到生成的图像中。 在此示例中,我们仅使用 bCR 的整数偏移量的 xy 平移,图 7d 显示生成的图像因此变得模糊。 这意味着 bCR 本质上是一种数据集增强,需要限制在实际有利于训练数据的对称性上,例如,x-flip 通常是可以接受的,但 y-flip 很少。 同时,使用 ADA,增强不会泄漏,因此可以在所有数据集中安全地使用相同的多样增强集。 我们还发现 ADA 和 bCR 的好处在很大程度上是相加的。 我们将 ADA 和 bCR 结合起来,以便首先将 ADA 应用于输入图像(真实的或生成的),然后 bCR 使用自己的一组增强创建该图像的另一个版本。 定性结果见附录 A。

在图 8a 中,我们进一步将我们的自适应增强与更广泛的替代方案进行比较:PA-GAN、WGAN-GP、zCR、辅助旋转和谱归一化。 我们还尝试修改我们的基线以使用更浅的映射网络,它可以用更少的数据进行训练,借鉴 DeLiGAN 的直觉。 最后,我们尝试用乘法 dropout 替换我们的增强,其每层强度由我们的自适应算法驱动。 我们花费了大量精力调整所有这些方法的参数,请参阅附录 D。我们可以看到 ADA 给出的结果明显好于其他方法。 虽然 PA-GAN 与我们的方法有些相似,但它的校验和任务不够强大,无法防止我们测试中的过度拟合。 图 8b 显示降低鉴别器容量通常是有害的并且不能防止过度拟合。

4.2 迁移学习

迁移学习通过从使用其他数据集训练的模型开始而不是随机初始化来减少训练数据要求。 几位作者在 GAN 的背景下对此进行了探索,Mo 等人最近通过在迁移过程中冻结鉴别器的最高分辨率层 (Freeze-D) 显示出很好的结果。 

我们探索了图 9 中的几种迁移学习设置,使用通过网格搜索为每个案例找到的最佳 Freeze-D 配置。 迁移学习的结果明显好于从头开始的训练,它的成功似乎主要取决于源数据集的多样性,而不是受试者之间的相似性。 例如,FFHQ(人脸)可以从 CELEBA-HQ(人脸,低多样性)或 LSUN DOG(更多样化)得到同样好的训练。 然而,LSUN CAT 只能从具有相当多样性的 LSUN DOG 进行训练,而不能从多样性较低的数据集中进行训练。 对于较小的目标数据集大小,我们的基线可以快速实现合理的 FID,但随着训练的继续,进展很快就会恢复。 ADA 再次能够几乎完全防止发散。 当与 ADA 一起使用时,Freeze-D 提供了一个小而可靠的改进,但不能单独防止发散。 

4.3 小数据集

我们用几个包含有限数量训练图像的数据集尝试了我们的方法(图 10)。 METFACES 是我们从大都会艺术博物馆的藏品中提取的 1336 张高质量人脸的新数据集。 BRECAHAD 仅包含 162 张乳腺癌组织病理学图像 (1360x1024); 我们将它们重组为 1944 个部分重叠的 512x512 裁剪图。动物面孔 (AFHQ) 包括狗、猫和野生动物的每个类别 5k 特写; 我们将它们视为三个独立的数据集,并为每个数据集训练了一个单独的网络。 CIFAR-10 包括 10 个类别的 50k 个小图像。

图 11 显示 FID 不是小型数据集的理想指标,因为当真实图像数量不足时,它会受固有偏差的支配。 我们发现 kernel inception distance (KID)——在设计上是无偏的——在实践中更具描述性,并且发现 ADA 比基线 StyleGAN2 有了显着的改进。 从头开始训练时尤其如此,但迁移学习也受益于 ADA。 在广泛使用的 CIFAR-10 基准测试中,我们在类条件设置中将 SOTA FID 从 5.59 提高到 2.42,并将起始分数 (IS) 从 9.58 提高到 10.24(图 11b)。 这一重大改进将 CIFAR-10 描绘成一个有限的数据基准。 我们还注意到,特定于 CIFAR 的架构调整具有显着效果。

5. 结论

我们已经证明,我们的自适应鉴别器增强能够可靠地稳定训练,并在训练数据供不应求时极大地提高结果质量。 当然,数据增强并不能替代真实数据——人们应该总是首先尝试收集大量高质量的训练数据,然后才使用数据增强来填补空白。 作为未来的工作,值得寻找最有效的增强集,看看最近发布的技术,如 U-net 鉴别器或多模态生成器,也可以援助有限的数据。 

启用 ADA 对训练单个模型的能耗影响可以忽略不计。 因此,使用它不会增加实际使用训练模型或开发需要大规模探索的方法的成本。 作为参考,附录 E 提供了我们执行的与本文相关的所有计算的细目分类; 该项目总共消耗了 325 兆瓦时的电力,或 135 个单 GPU 年,其中大部分可归因于广泛的比较和扫描。

有趣的是,鉴别器增强的核心思想是由其他三个研究小组在并行工作中独立发现的:Z. Zhao 等人、Tran 等人和 S. Zhao 等人。 我们推荐这些论文,因为它们都提供了一组不同的直觉、实验和理论依据。 Z. Zhao 和 S. Zhao 提出了与我们所做的基本相同的增强机制,但它们仅凭经验研究了泄漏伪影的缺失。 Tran 提出了基于可逆性的理论依据,但得出了不同的论点,导致更复杂的网络架构,以及对可能的增强集的显着限制。 这些作品都没有考虑自适应调整增强强度的可能性。 我们在第 3 节中的实验表明,最佳增强强度不仅在不同内容和大小的数据集之间变化,而且在训练过程中也会变化——即使是一组最佳的固定增强参数也可能会影响性能。

很难直接比较平行作品之间的结果,因为所有论文中使用的唯一数据集是 CIFAR-10。 遗憾的是,其他三篇论文使用 10k 生成图像和 10k 验证图像 (FID-10k) 计算 FID,而我们使用遵循 Heusel 等人的原始建议,并使用 50k 生成图像和所有训练图像。 因此,它们的 FID-10k 数字无法与图 11b 中的 FID 相比较。 出于这个原因,我们还为我们的方法计算了 FID-10k,获得无条件的 7.01 和有条件的 6.54。 这些与并行工作的无条件 9.89 或 10.89 以及有条件的 8.30 或 8.49 相比是有利的。 似乎有可能将所有四篇论文的想法进行某种组合可以进一步改善我们的结果。 例如,更多样化的增强集或对比正则化可能值得测试。

6. 更广泛的影响

数据驱动的生成建模意味着学习一种计算方法,以完全基于示例生成复杂的数据。 这是机器学习中的一个基础问题。 除了它们的基本性质外,生成模型在应用机器学习研究中还有多种用途,如先验、正则化等。 在这些角色中,他们提高了计算机视觉和图形算法的能力,以分析和合成逼真的图像。

这项工作中提出的方法可以使用比现有方法所需的数据少得多的数据来训练高质量的生成图像模型。 因此,它主要解决了一个深层次的技术问题,即有多少数据足以让生成模型成功地获取数据中必要的共性和关系。

从应用的角度来看,这项工作有助于提高效率; 它没有引入基本的新功能。 因此,在关于计算机视觉和图形的更广泛影响的积极讨论中,这里的进展似乎不会对总体主题——监视、真实性、隐私等——产生实质性影响。

具体来说,生成模型对图像和视频真实性的影响是一个活跃讨论的话题。 大多数注意力都围绕着允许语义控制和有时操纵现有图像的条件模型。 我们的算法不提供对生成图像中高级属性(例如,身份、姿势、人的表情)的直接控制,也不支持对现有图像的直接修改。 然而,随着时间的推移和通过其他研究人员的工作,我们的进步也可能会导致这些类型的模型得到改进。

这项工作的贡献使得使用自定义图像集训练高质量的生成模型变得更加容易。 通过这种方式,我们消除或至少显着降低了在许多应用研究领域应用 GAN 类型模型的障碍。 我们希望并相信这将加速几个此类领域的进展。 例如,对生物标本(组织、肿瘤等)可能出现的空间进行建模是一个不断发展的研究领域,似乎长期受到高质量数据有限的影响。 总的来说,生成模型有望增加对许多现实世界现象中复杂且难以确定的关系的理解; 我们的工作有望扩大可研究现象的广度。

参考

Karras T, Aittala M, Hellsten J, et al. Training generative adversarial networks with limited data[J]. Advances in neural information processing systems, 2020, 33: 12104-12114.

S. 总结

S.1 核心思想

为了解决用于训练 GAN 的数据少的问题,作者提出了一种新的数据增强方式:自适应鉴别器增强(adaptive discriminator augmentation,ADA)。相比于现有的数据增强方式,该方法研究如何避免把增强泄漏到生成数据中。

作者发现,当增强的数据出现的概率 p < 1,即有 1-p 的概率用于训练的是未增强的数据时,有可能避免把增强泄漏到生成数据中。对于不同的数据集大小,这个 p 值不是固定的,所以 ADA 就是基于训练,自适应的找出合适的 p。

S.2 方法

分别用 D_train、D_validation 和 D_generated 表示训练集、验证集和生成图像的鉴别器输出,以及它们在 N 个连续 minibatch 上的均值用 E[·] 表示,则有

对于 r_v 和 r_t,r = 0 表示没有过拟合,r = 1 表示完全过拟合。控制方式如下:将 p 初始化为零,并根据过拟合程度每四个 minibatch 调整一次 p 值。 如果过拟合太多/太少,将 p 递增/递减固定量。我们设置调整大小,以便 p 可以足够快地从 0 上升到 1。上升到 1 之后,把 p 置为 0。

S.3 其他贡献

作者对真实数据、生成数据以及生成器使用数据增强,这要求这些增强对生成器和鉴别器是可鉴别的。另一篇论文,可鉴别数据增强(Differentiable Augmentation,DA)就是对本贡献的进一步研究(消融实验)。

以某个概率 p 执行数据增强可以理解为:增强模块生效的概率是 p。

  • 当 p=0 时,训练时未使用数据增强。不能解决数据有限的问题。
  • 当 p=1 时,完全使用增强的数据进行训练。模型未见过真实数据,从而使增强泄漏到生成数据中,即生成器学习到了(可能只是一部分)增强的数据分布。(DA 的研究结果与此矛盾,存疑) 

S.4 相关内容分析

“随机化是针对每个增强和 minibatch.中的每个图像单独完成的。 鉴于流程中有许多增强,即使 p 值相当小,鉴别器也不太可能看到完全干净的图像。 尽管如此,只要 p 保持在实际安全门限以下,生成器就会被引导只生成干净的图像。”

这是本文中的一段内容,DA 一文对可以生成干净图像做出了解释:使用可鉴别增强使生成器学到的是真实数据的分布,而不是增强数据的分布。

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是使用WGAN-GP训练mnist数据集的步骤: 1.导入必要的库和模块 ```python import tensorflow as tf from tensorflow.keras import layers import numpy as np import matplotlib.pyplot as plt ``` 2.加载mnist数据集 ```python (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data() train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32') train_images = (train_images - 127.5) / 127.5 # 将像素值归一化到[-1, 1]之间 BUFFER_SIZE = 60000 BATCH_SIZE = 256 train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE) ``` 3.定义生成器和判别器模型 ```python def make_generator_model(): model = tf.keras.Sequential() model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,))) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU()) model.add(layers.Reshape((7, 7, 256))) assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制 model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False)) assert model.output_shape == (None, 7, 7, 128) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU()) model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False)) assert model.output_shape == (None, 14, 14, 64) model.add(layers.BatchNormalization()) model.add(layers.LeakyReLU()) model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')) assert model.output_shape == (None, 28, 28, 1) return model def make_discriminator_model(): model = tf.keras.Sequential() model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1])) model.add(layers.LeakyReLU()) model.add(layers.Dropout(0.3)) model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')) model.add(layers.LeakyReLU()) model.add(layers.Dropout(0.3)) model.add(layers.Flatten()) model.add(layers.Dense(1)) return model ``` 4.定义损失函数和优化器 ```python generator_optimizer = tf.keras.optimizers.Adam(1e-4) discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) def discriminator_loss(real_output, fake_output): real_loss = tf.reduce_mean(real_output) fake_loss = tf.reduce_mean(fake_output) return fake_loss - real_loss def generator_loss(fake_output): return -tf.reduce_mean(fake_output) ``` 5.定义训练函数 ```python @tf.function def train_step(images): noise = tf.random.normal([BATCH_SIZE, 100]) with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: generated_images = generator(noise, training=True) real_output = discriminator(images, training=True) fake_output = discriminator(generated_images, training=True) gen_loss = generator_loss(fake_output) disc_loss = discriminator_loss(real_output, fake_output) gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables) gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables) generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables)) discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables)) ``` 6.训练模型 ```python EPOCHS = 100 noise_dim = 100 num_examples_to_generate = 16 # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度) seed = tf.random.normal([num_examples_to_generate, noise_dim]) generator = make_generator_model() discriminator = make_discriminator_model() for epoch in range(EPOCHS): for image_batch in train_dataset: train_step(image_batch) # 每 15 个 epoch 生成一次图片 if epoch % 15 == 0: generate_and_save_images(generator, epoch + 1, seed) # 生成最终的图片 generate_and_save_images(generator, EPOCHS, seed) ``` 7.生成图片 ```python def generate_and_save_images(model, epoch, test_input): # 注意 training` 设定为 False # 因此,所有层都在推理模式下运行(batchnorm)。 predictions = model(test_input, training=False) fig = plt.figure(figsize=(4, 4)) for i in range(predictions.shape[0]): plt.subplot(4, 4, i+1) plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray') plt.axis('off') plt.savefig('image_at_epoch_{:04d}.png'.format(epoch)) plt.show() ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值