码字不易!
如果觉得有用请点赞!
上文:让算法拥有想象力的cycleGAN(一)原理分析,阐述了cycleGAN的基本原理,本文继续记录自己的pytorch实现过程,并分析视觉结果和损失函数曲线,包含以下几个部分:
(1)结果分析
(2)生成器结构
(3)判别器结构
(4)损失函数选择
(5)优化器选择
(6)生成器训练过程
(7)判别器训练过程
网上的实现版本中对于两个生成器和两个判别器的更新有些不同,我选择的是先同时更新两个生成器,再分别更新两个判别器。
1、结果分析
自己在两个数据集上进行了实验,第一个数据集是:卡通人脸-欧美人脸,第二个数据集是:素描鞋-真实鞋。下面分两个数据集的结果进行介绍:
1.1 素描鞋-真实鞋
素描鞋-真实鞋这个数据集包含的样本比较多,训练网络的时候效果也比较好,可以说是自己近期复现论文比较成功的一次了。输入的素描鞋通过生成器得到的真实鞋如下:
从上图结果可以看出,输入一个素描运动鞋时输出是粉色的运动鞋,这个上色结果不是单单给与颜色,而且鞋跟和鞋面以及鞋带是不同颜色的,说明网络真的具有某种“想象”的能力。下图展示了输入真实鞋输出素描鞋的效果:
输入是皮靴时输出的素描皮靴也完全保留了其轮廓特征。
下图展示了更多的转换结果(四个为一组,每组左边是输入,右边是对应输出):
损失函数依旧震动的比较严重,说明生成器和判别器正在博弈:
1.2 卡通人脸-欧美人脸
在卡通人脸和欧美人脸之间的转换结果就不是很好,可能是自己的人脸数据集较小,而且特征更加复杂。下图的视觉结果中左边是输入,右边是输出:
损失函数的收敛情况如下:
在人脸数据上的收敛要好很多,尽管人脸数据没有鞋数据的生成效果好。
2、生成器结构
生成器G和生成器F采用的都是首先使用步长大于1的卷积降低特征图的长度和宽度,然后使用残差连接,最后使用反卷积获得和输入图像尺寸相同的输出,这里采用的是StarGAN的生成器,前面GAN系列的文章已经给出了代码:
class
3、判别器结构
采用的依旧是StarGAN的判别器,输入图像,输出patch的判别结果,前面GAN系列的文章已经给出了代码:
class
4、损失函数选择
损失函数包含三种:对抗损失,循环一致损失和identity损失。对抗损失采用LSGAN的方式,所以是MSE Loss;循环一致损失按照论文采用L1 Loss;identity损失同样采用L1 Loss:
loss_function_GAN
5、优化器选择
优化时采用生成器G和生成器F同时进行,判别器DX和判别器DY分开进行的优化策略,所以需要三个optimizer。优化算法采用收敛性能较好的Adam优化器,其中beta1和beta2分别为0.5和0.999,训练过程中不进行学习率的动态调整:
optimizer_G
6、生成器训练过程
损失函数的计算是CycleGAN最核心也是最复杂的内容,生成器的损失计算分为三个过程:
(1)对域X和域Y计算identity损失
(2)生成器计算生成样本的对抗性损失
(3)计算循环一致损失,即重构损失
# 1:计算生成器损失
7、判别器训练过程
判别器DX和判别器DY的训练过程是分开的,二者的训练原理相同,这里仅以DX为例。判别器DX的对抗性损失包含对真实样本的损失和对伪造的损失两个部分,也就是要分别计算:
# 2:计算判别器X损失
8、总结
从代码实现的角度看,网络结构之类都不是很重要,最重要的是损失函数的计算。CycleGAN的损失计算比较绕,很容易写错生成样本和真实样本与两个生成器和两个判别器的输入输出关系,以及对抗性损失的更新次序。
参考:
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/tree/master/models
https://github.com/aitorzip/PyTorch-CycleGAN
https://github.com/znxlwm/pytorch-CycleGAN/blob/master/pytorch_cycleGAN.py#L144