原文地址:cycleGAN模型构建及代码解读及细节_HNU_刘yuan的博客-CSDN博客_cyclegan代码
cycleGAN简介
cycleGAN是一种由Generative Adversarial Networks发展而来的一种无监督机器学习,是在pix2pix的基础上发展起来的,主要应用于非配对图片的图像生成和转换,可以实现风格的转换,比如把照片转换为油画风格,或者把照片的橘子转换为苹果、马与斑马之间的转换等。因为不需要成对的数据集就能够转换,所以在数据准备上会简单很多,十分具有应用前景。
cycleGAN名字中之所以有一个cycle,我觉得应该是原图经过一种生成网络转换后得到另一种风格的图片,然后还要经过另一种生成网络转换后尽可能的接近原图,形成了一个循环,所以被称为cycleGAN。
所以有:
AtoBtoA = G_BtoA(G_AtoB(real_A)) 从A风格转换到B风格,又转换为A风格
BtoAtoB = G_AtoB(G_BtoA(real_B)) 从B风格转换到A风格,又转换为B风格
cycleGAN中的网络
cycleGAN由两个生成网络和两个判别网络构成
- G_AtoB() 看作是风格A向风格B的生成网络
- G_BtoA() 看作是风格B向风格A的生成网络
- dis_A() 看作是判别输入图片是否属于风格A的判别网络
- dis_B() 看作是判别输入图片是否属于风格B的判别网络
- AtoB = G_AtoB(real_A) 看作是real_A经过生成网络转换得到的风格B的照片
- BtoA = G_BtoA(real_B) 看作是real_B经过生成网络转换得到的风格A的照片
其中G_AtoB()和G_BtoA()的输入为[B, C, W, H],即batchsize, channels, width, height,输出一般与输入相同;
其中dis_A()和dis_B()的输入为[B, C, W, H],即batchsize, channels, width, height,输出的维度是[B, 1],里面的是经过sigmoid函数输出的,所以取值范围在[0, 1]进行分类。
生成器由三个部分组成:
- 编码器(由三层卷积网络构成,并进行归一化;使用了残差块,减弱梯度消失,使网络可以自己自适应地调节层数的深浅,变得更深的同时更平滑)
- 转换器
- 解码器(用到反卷积(逆卷积)和卷积层,经过残差结构,第一、二层反卷积,第三层卷积)
辨别器:用的是5层卷积,将通道数减为1,最后进行池化平均,再reshape成[batchsize 1]
损失函数:
cycleGAN中用到了两种损失函数,
- MSE,应用在标签中,用来判断discriminator输出的label和真实lable之间的loss。
gen_AtoB中,Dis_B判断AtoB生成的图片与真实标签之间的loss
gen_BtoA中,Dis_A判断BtoA生成的图片与真实标签之间的loss
Dis_A中 real_A与真实标签之间的loss | | B2A与虚假标签之间的loss
Dis_B中 real_B与真实标签之间的loss | | A2B与虚假标签之间的loss - L1,应用在图片中,衡量图片与图片之间的loss。 real_A和 A2B2A之间
real_B和 B2A2B之间
real_A和 B2A(real_A)
real_B和 A2B(real_B)
其中的第三种和第四种情况可以理解为:经过生成该图片风格的生成器生成的图片应该尽量与原图保持一致。也被成为identity loss。可以理解成生成器Gen_AtoB负责x域(domain)到y域图像的生成,如果输入y域的图片,输出仍然是y域的图片,比较符合直觉,用的是L1函数。