CycleGAN的原理可以概述为:
将一类图片转换成另一类图片 。也就是说,现在有两个样
本空间,X和Y,我们希望把X空间中的样本转换成Y空间中
的样本。(获取一个数据集的特征,并转化成另一个数据
集的特征)
这样来看:实际的目标就是学习从X到Y的映射。我们设这
个映射为F。它就对应着GAN中的 生成器 ,F可以将X中的
图片x转换为Y中的图片F(x)。对于生成的图片,我们还需要
GAN中的 判别器 来判别它是否为真实图片,由此构成对抗
生成网络
在足够大的样本容量下,网络可以将相同的输入图像集合
映射到目标域中图像的任何随机排列,其中任何学习的映
射可以归纳出与目标分布匹配的输出分布(即:映射F完全
可以将所有x都映射为Y空间中的同一张图片,使损失无效
化)。
因此,单独的对抗损失Loss不能保证学习函数可以
将单个输入Xi映射到期望的输出Yi。
对此,作者又提出了所谓的“循环一致性损失”
(cycle consistency loss)。
我们希望能够把 domain A 的图片(命名为 a)转
化为 domain B 的图片(命名为图片 b)。
为了实现这个过程,我们需要两个生成器 G_AB 和
G_BA,分别把 domain A 和 domain B 的图片进行
互相转换。
将X的图片转换到Y空间后,应该还可以转换回来。
这样就杜绝模型把所有X的图片都转换为Y空间中的
同一张图片了
最后为了训练这个单向 GAN 需要两个 loss,分别是
生成器的重建 loss 和判别器的判别 loss。
判别 loss:判别器 D_B 是用来判断输入的图片是否
是真实的 domain B 图片
CycleGAN 其实就是一个 A→B 单向 GAN 加上一个
B→A 单向 GAN。两个 GAN 共享两个生成器,然
后各自带一个判别器,所以加起来总共有两个判别器
和两个生成器。
一个单向 GAN 有两个 loss,而 CycleGAN 加起来
总共有四个 loss。
对颜色、纹理等的转换效果比较好,对多样性高的、
多变的转换效果不好(如几何转换)
代码
import tensorflow as tf
import glob
from matplotlib import pyplot as plt
%matplotlib inline
AUTOTUNE = tf.data.experimental.AUTOTUNE
import os
os.listdir('../input/apple2orange/apple2orange')
imgs_A = glob.glob('../input/apple2orange/apple2orange/trainA/*.jpg')
imgs_B = glob.glob('../input/apple2orange/apple2orange/trainB/*.jpg')
test_A = glob.glob('../input/apple2orange/apple2orange/testA/*.jpg')
test_B = glob.glob('../input/apple2orange/apple2orange/testB/*.jpg')
def read_jpg(path):
img = tf.io.read_file(path)
img = tf.image.decode_jpeg(img, channels=3)
return img
def normalize(input_image):
input_image = tf.cast(input_image, tf.float32)/127.5 - 1
return input_image
def load_image(image_path):
image = read_jpg(image_path)
image = tf.image.resize(image, (256, 256))
image = normalize(image)
return image
train_a = tf.data.Dataset.from_tensor_slices(imgs_A)
train_b = tf.data.Dataset.from_tensor_slices(imgs_B)
test_a = tf.data.Dataset.from_tensor_slices(test_A)
test_b = tf.data.Dataset.from_tensor_slices(test_B)
BUFFER_SIZE = 200
train_a = train_a.map(load_image,
num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
train_b = train_b.map(load_image,
num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
test_a = test_a.map(load_image,
num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(1)
test_b = t