昇思25天打卡训练营第19天|CycleGAN图像风格迁移互换

模型介绍
模型简介
CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

模型结构
CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):

为了方便理解,这里以苹果和橘子为例介绍。上图中  𝑋可以理解为苹果, 𝑌为橘子; 𝐺为将苹果生成橘子风格的生成器, 𝐹为将橘子生成的苹果风格的生成器, 𝐷𝑋和  𝐷𝑌为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

图中苹果图片  𝑥经过生成器  𝐺得到伪橘子  𝑌̂ ,然后将伪橘子  𝑌̂ 结果送进生成器  𝐹又产生苹果风格的结果  𝑥̂ ,最后将生成的苹果风格结果  𝑥̂ 与原苹果图片  𝑥一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。

数据集
本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理,为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,以省略大部分数据预处理的代码。

数据集下载
使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。

数据集加载
使用 MindSpore 的 MindDataset 接口读取和解析数据集。

可视化
通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。

构建生成器
本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:

具体的模型结构请参照下文代码:

构建判别器
判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2d 、 BatchNorm2d 和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

优化器和损失函数
根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器  𝐺及其判别器  𝐷𝑌,目标损失函数定义为:

其中  𝐺试图生成看起来与  𝑌中的图像相似的图像  𝐺(𝑥),而  𝐷𝑌的目标是区分翻译样本  𝐺(𝑥)和真实样本  𝑦,生成器的目标是最小化这个损失函数以此来对抗判别器。即  𝑚𝑖𝑛𝐺𝑚𝑎𝑥𝐷𝑌𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌)  。

单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于  𝑋的每个图像  𝑥,图像转换周期应能够将  𝑥带回原始图像,可以称之为正向循环一致性,即  𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。对于  𝑌,类似的  𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥。可以理解采用了一个循环一致性损失来激励这种行为。

循环一致损失函数定义如下:

循环一致损失能够保证重建图像  𝐹(𝐺(𝑥))  与输入图像  𝑥紧密匹配。

前向计算
搭建模型前向计算损失的过程,过程如下代码。

为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

计算梯度和反向传播
其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

模型训练
训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化  𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)2]

训练生成器:如 CycleGAN 论文所述,我们希望通过最小化  𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)2]来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

模型推理
下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第一行为原图,第二行为对应生成的结果图。

参考
[1] I. Goodfellow. NIPS 2016 tutorial: Generative ad-versarial networks. arXiv preprint arXiv:1701.00160,2016. 2, 4, 5

[2] A. Shrivastava, T. Pfister, O. Tuzel, J. Susskind, W. Wang, R. Webb. Learning from simulated and unsupervised images through adversarial training. In CVPR, 2017. 3, 5, 6, 7

  • 14
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值