论文链接:
pytorch开源代码:
目录
1.应用场景
不属于强监督,图像之间无需完全匹配,可实现图像风格转换。
2.数据处理
数据处理包括:数据扩充和数据归一化
3.网络结构
包括生成器及判别器,生成器用来实现图像风格A与图像风格B之间的相互转换,生成器有两个分别为图像风格A->图像风格B的转换和图像风格B->图像风格A的转换;判别器也有两个分别判断两个生成器转换后图像与目标图像比较的真假。
3.1生成器
生成器网络结构相比判别器网络结构复杂,可参考开源代码中网络结构,也可以根据实际计算能力及任务复杂情况修改网络结构。
值得注意的有以下三点:
1.网络结构最后一层卷积通道数为生成图像通道数,若输入图像为rgb图像则为三通道数据,这时卷积核输出通道数为3;若是输入图像为单通道数据则输出通道数为1。
2.最后一层激活函数选用问题上,可以使用sigmoid或是tanh,sigmoid输出结果在0-1之间,所以对应的输入应该也归一化到0-1之间;tanh输出结果为(-1,1),所以对应的输入图像归一化为(-1,1),选择哪种激活函数与数据归一化方式有关。
3.生成器网络中的归一化方式:选用IN,而不是使用BN.
3.2判别器
判别器网络结构相比生成器较为简单,判别器的功能为实现真假图像的二分类,真为目标图像,假为生成器生成图像,即生成对抗网络中的对抗概念,同时向生成网络传递对抗损失,优化生成器参数。
同样,也可以根据计算资源能力和任务复杂难易程度,选择合适的判别器。
值得注意的有以下两点:
1. 判别器最后一层与普通二分类网络结构不一样,不再使用sigmoid等激活函数计算该类别是否为真的概率,而是直接去掉sigmoid层,具体原因可以查看wgan lsgan等GAN方法介绍。
2. 判别器中同样不再使用BN,而是采用IN.
4.损失函数
4.1生成器损失
共包括三种:判别器损失+cyclegan损失+idt损失(论文里叙述较少,但是源码中有这部分内容,实际对比发现确实有效果,具体原因还没有仔细研究,看损失函数的计算大概是保证风格转换后与原图保证一致,使用原图约束风格转换后生成图像)
四天搞懂生成对抗网络(四)——CycleGAN的绝妙设计:双向循环生成的结构 - 云+社区 - 腾讯云
这个链接下有针对这个问题的图像对比,“论文中提到,CycleGAN使用identity loss的目的是在迁移的过程中保持原色调”。
idt损失权重应该比cyclegan损失权重小,代码中应该是2倍关系。
ga_gan_loss = reduce_mean((d_a(g_a(img_rb)) - 1) ** 2)
ga_cyc_loss = reduce_mean(abs(img_rb- g_b(g_a(img_rb))))
ga_ide_loss = reduce_mean(abs(img_ra - g_a(img_ra)))
4.2判别器损失
可以使用wgan损失,lsgan损失,可以参考pytorch源码。
5.训练策略
学习率以0.002开始,训练100轮,然后逐渐递减至0,训练100轮(图像生成学习率大部分都比较小)
6.其他
该网络模型共有四个网络结构,分别为A->B生成器(1),B->A生成器(2),A->B判别器(3),B->A判别器(4),其中训练时四个网络结构均需要进行前向传递和反向传递,测试时仅需加载需要的生成器参数进行前向传递,无需加载判别器参数。因此在训练时占用的显存较大,在测试时相对较小。