背景:
Cycle-GAN是一个2017年推出的直击产业痛点的模型。众所周知,在一系列视觉问题上是很难以找到匹配的高质量图像作为target来供模型学习的,比如在超分辨领域内对于一个低分辨率的物体图像,未必能找到同样场景的高分辨率图像,这使得一系列深度学习模型的适应性有限。上述的困难总结起来就是:由于模型训练时必须依赖匹配的图像,而除非有目的的去产生这样的图像否则无法训练,并且很容易造成数据有偏。
Cycle-GAN训练的目的则避开了上述困难;该模型的思路是旨在形成一个由数据域A到数据域B的普适性映射,学习的目标是数据域A和B的风格之间的变换而非具体的数据a和b之间的一一映射关系。从这样的思路出发Cycle-GAN对于数据一一匹配的依赖性就不存在了,可以解决一系列问题,因此该模型的设计思路与具体做法十分值得学习。
总的来说,基于Cycle-GAN的模型具有较强的适应性,能够适应一系列的视觉问题场合,比如超分辨,风格变换,图像增强等等场合。
下面附一张匹配和非匹配图像的说明
设计思路:
通常的GAN的设计思路从信息流的角度出发是单向的,如下图所示:使用Generator从a产生一个假的b,然后使用Determinator判断这个假的b是否属于B集合,并将这个信息反馈至Generator,然后通过逐次分别提高Generator与Discriminator的能力以期达到使Generator能以假乱真的能力,这样的设计思路在一般有匹配图像的情况下是合理的。
而在Cycle-GAN为了能不依赖一一对应的图像,需要确保学习到的映射不能将不同的a1、a2、...映射为同一个b,尽管这个b确实可以达到以假乱真的结果。为了防止这种情况Cycle-GAN引入了cycle概念,简单来说就是将假的b映射回A,产生一个假的a1,并判断这假的a1是否同真的a1近似,用这样的方式来确保模型真正能学习到一对一的映射而不是一对多。这个思路来自于机器翻译邻域[1], 具体说来如下图所示:
在模型中定义两个变换G、F,分别表示从输入域X到目标域Y的变换和逆变换,并且使用Dx和Dy来分辨判断F和G的效果,在使用Dx和Dy的同时,引入两个cycle过程,第一个cycle过程(图中b),是指使用真实的x产生一个估计的y,再使用这个估计的y进行逆变换产生一个估计的x,此时评价真实的x和这个估计的x的差别,同样的Y到X的过程也引入这样的cycle过程(图中c)
整个网络过程可以如下所示:来自于(https://www.jianshu.com/p/64bf39804c80),这里GeneratorAtoB和GeneratorBtoA以及discriminatorA和discriminatorB的网络结构两两相同[2]
loss函数:
loss函数的组成最能反映这种设计思路,Cycle-GAN的loss总体来说可以分为两部分,一部分是GAN loss,一部分是Cycle loss
具体的表现形式是
上述是作者论文中提出的初始版本,在代码里面使用的是MSELoss[3]
训练注意事项:
为了保证训练出的模型具有较强的稳定性,使用了两种技术:第一,基于前人的成果将GAN loss变为非负的似然计算改为最小平方loss[4], 第二,使用了GAN训练的“记忆”技术,即在更新Discriminator时使用先期存储的数据而非刚刚由Generator产生数据[5]。 另外具体的训练细节可以参考下一篇博客 对于Cycle-GAN的代码介绍
参考资料:
[1]. R. W. Brislin. Back-translation for cross-cultural research. Journal of cross-cultural psychology, 1(3):185–216, 1970.
[2]. https://github.com/vanhuyz/CycleGAN-TensorFlow.git
[3]. https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
[4]. Multiclass generative adversarial networks with the l2 loss function. arXiv preprint arXiv:1611.04076, 2016.
[5]. A. Shrivastava, T. Pfister, O. Tuzel, J. Susskind, W. Wang, and R. Webb. Learning from simulated and unsupervised images through adversarial training. arXiv preprint arXiv:1612.07828, 2016.