CycleGAN 与非配对图像转换

本文介绍CycleGAN原理以及在tensorflow中实现。

 

一、CycleGAN 的原理

cGAN 和对应的 pix2pix 模型,都能够解决一类“图像翻译 ”问题 。 但是 pix2pix 模型要求训练样本必须是“严格成对”的,这种样本往往比较难以获得,CycleGAN 不必使用成对样本也可以进行“图像翻译”。CycleGAN与 pix2pix的不同点在于,它可以利用不成对数据训练出从 X 空间到 Y 空间的映射 。 例如,只要搜集了大量照片以及大量油画图片,可以学习到如何把照片转换成油画。

CycleGAN 的原理:算法的目标是学习从空间 X 到空间 Y 的映射,设这个映射为 F。 它对应着 GAN 中的生成器, F 可以将 X 中的图片 x 转化为 Y 中的图片 F(x)。对于生成的图片,还需要 GAN 中的判别器来判别器是否为真实图片,由此构成对抗生成网络 。但由于没有成对数据,这个网络是无法训练的。对此,作者又提出了所谓的“循环一致性损失”( cycle consistency loss )。让再假设一个映射 G,它可以将 Y 空间中的图片 y 转换为 X 中的图片 G(y)。 CycleGAN,同时学习 F 和 G 两个映射,并要求 F(G(y)) = y,以及 G(F(x)) = x。也就是说,将 x 的图片转换到 Y 空间后,应该还可以转换回来。

循环一致性损失定义:

总损失定义:

CycleGAN 的主要想法是上述的“循环一致性损失”,利用这个损失 3 可以巧妙地处理 X 空间和 Y 空间训练样本不一一配对的问题。

 

二、在 TensorFlow 中用训练 CycleGAN 模型

1、下载数据集并进行训练

(1)下载数据集

apple2orange数据集包含了苹果和橘子的图像,运行命令下载数据集:bash download_dataset.sh apple2orange,运行报错:wget: command not found,显然是因为没有安装wget导致的,wget用英语定义就是the non-interactive network downloader,翻译过来就是非交互的网络下载器。这里使用homebrew安装wget,Homebrew为macOS提供缺失的软件包管理器,使用Homebrew可以安装Apple没有预装但你需要的东西,Homebrew会将软件包安装到独立目录,并将其文件软链接至 /usr/local。Homebrew 不会将文件安装到它本身目录之外,所以可将 Homebrew 安装到任意位置。先安装homebrew,命令为:/usr/bin/ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)",安装成功后,利用homwbrew安装wget,命令为:brew install wget,安装成功。download_dataset.sh文件内容如下:

URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./data/$FILE.zip
TARGET_DIR=./data/$FILE/
wget -N $URL -O $ZIP_FILE
mkdir -p $TARGET_DIR
unzip $ZIP_FILE -d ./data/
rm $ZIP_FILE

运行报错:时间戳与 -O 结合使用没有任何效果,去掉wget的 -N参数即可。-N参数表示只获取比本地新的文件,-O参数表示将文档写入$ZIP_FILE中。至此,运行下载数据集命令不报错,数据集下载完成,生成 data/apple2orange 目录。其中 trainA、 testA 中保存了苹果的图像, trainB、 testB 中保存了橙子的图像,如图:

(2)转换图片格式

由于该项目使用 tfrecords 读取数据,再将图片转换为tfrecords格式(大数据文件格式),命令为:

python build_data.py 
--X_input_dir data/apple2orange/trainA 
--Y_input_dir data/apple2orange/trainB 
--X_output_file data/tfrecords/apple.tfrecords 
--Y_output_file data/tfrecords/orange.tfrecords

运行报错 except os.error, e: SyntaxError: invalid syntax,这是python2的捕获方法,在python3中为except Exception as e,因此代码改为:

try:
  os.makedirs(output_dir)
except os.error as e:
  pass

至此,数据格式转换成功。

(3)训练模型

运行训练模型的命令:

python train.py 
--X data/tfrecords/apple.tfrecords 
--Y data/tfrecords/orange.tfrecords 
--image_size 256

运行报错:absl.flags._exceptions.IllegalFlagValueError: flag --lambda1=10.0: Expect argument to be a string or int, found <class 'float'>,原因是需要int型,而传入float型,将 tf.flags.DEFINE_integer 改为 flags.DEFINE_float 即可。训练开始后,程序会在 checkpoints 文件夹中建立一个以当前时间命名的目录,如“checkpoints/20190624-1053”,训练时的曰志和模型都会保存在该文件夹中。ckpt为tensorflow的模型文件格式,其他几种格式参考:https://blog.csdn.net/sinat_31337047/article/details/81483006

此外,每隔 100 步,程序还会在屏幕上打出当前步数和损失, 可以通过它们来监控模型的训练。

更方便的做法是使用 TensorBoard,即运行 : tensorboard --logdir checkpoints/20190624-1053/,运行命令报错:AttributeError: module 'tensorboard.util' has no attribute 'Retrier',原因是tensorboard与tensorflow版本不符合。使用pip list 查看tensorflow版本,我的是1.13.1,在https://github.com/tensorflow/tensorboard/releases?after=1.13.1找到相对应的tensorboard版本,即1.13.0,重新安装1.13.0的tensorboard即可:pip install tensorboard==1.13.0。

(4)用训练好的模型进行测试

将模型导出为 pb 文件,运行命令:

python export_graph.py 
--checkpoint_dir checkpoints/20190624-1146/ 
--XtoY_model apple2orange.pb 
--YtoX_model orange2apole.pb 
--image_size 256

运行命令使用模型pretrained/apple2orange.pb将图片data/apple2orange/testA/n07740461_1661.jpg进行转换,把生成的图片存放到data/apple2orange/output_sample.jpg中,如下:

python inference.py
--model pretrained/apple2orange.pb
--input data/apple2orange/testA/n07740461_1661.jpg
--output data/apple2orange/output_sample.jpg
--image_size 256

 

2、使用自己的数据进行训练

(1)准备两个文件夹, 一个文件夹中存放 X 空间内的图片,另一个文件夹中存放 Y 空间 中的文件 。使用数据集 man2woman.zip, 该数据集是一个人脸数据集,用 CycleGAN做一个实验:将男性变成女性以及将女性变成男性 。man2woman 数据集是从 CelebA 数据集中整理得到的,后者是一个大型的人脸数据集拥有 20 万张人脸图片。CelebA数据集下载地址https://pan.baidu.com/s/1eSNpdRG?errno=0&errmsg=Auth%20Login%20Sucess&&bduss=&ssnerror=0&traceid=#list/path=%2F&parentPath=%2F

(2)为了训练 CycleGAN,需要先将图片转换成 tfrecords 形式。运行命令后,得到了两个 tfrecords 文件 。

(3)直接利用这两个文件进行训练即可。训练的过程比较漫长,最好都打开 TensorBoard 观察训练的 Loss 和图像生成情况 。如果训练的过程发生了中断,可以不从头开始训练,指定--load_model 参数,可以从之前保存的模型中恢复并继续训练。

(4)使用训练好的模型就可以进行测试了。最终男生照片变成女性照片,女性照片变成男性照片。

 

三、程序结构分析

1、CycleGAN 模型定义(model.py)

def model(self):
  # 读入x空间数据
  X_reader = Reader(self.X_train_file, name='X',
      image_size=self.image_size, batch_size=self.batch_size)
  # 读入y空间数据
  Y_reader = Reader(self.Y_train_file, name='Y',
      image_size=self.image_size, batch_size=self.batch_size)

  # 将读入数据保存到x、y变量中
  x = X_reader.feed()
  y = Y_reader.feed()

  # 根据 self.G、self.F、x、y定义循环一致性损失 cycle_loss
  cycle_loss = self.cycle_consistency_loss(self.G, self.F, x, y)

  # X -> Y(self.G)
  fake_y = self.G(x)
  # 定义 self.G 生成图片的损失
  G_gan_loss = self.generator_loss(self.D_Y, fake_y, use_lsgan=self.use_lsgan)
  G_loss =  G_gan_loss + cycle_loss
  # 定义 Y 空间鉴别器的损失
  D_Y_loss = self.discriminator_loss(self.D_Y, y, self.fake_y, use_lsgan=self.use_lsgan)

  # Y -> X(self.F)
  fake_x = self.F(y)
  # 定义 self.F 生成图片的损失
  F_gan_loss = self.generator_loss(self.D_X, fake_x, use_lsgan=self.use_lsgan)
  F_loss = F_gan_loss + cycle_loss
  # 定义 X 空间鉴别器的损失
  D_X_loss = self.discriminator_loss(self.D_X, x, self.fake_x, use_lsgan=self.use_lsgan)

其中,self.F和self.G是生成器,D_X,D_Y是鉴别器

 

2、循环一致性损失定义(model.py)

def cycle_consistency_loss(self, G, F, x, y):
  # L1 损失
  forward_loss = tf.reduce_mean(tf.abs(F(G(x))-x))
  backward_loss = tf.reduce_mean(tf.abs(G(F(y))-y))
  loss = self.lambda1*forward_loss + self.lambda2*backward_loss
  return loss

 

3、生成器的损失(model.py)

def generator_loss(self, D, fake_y, use_lsgan=True):
  # use_lsgan指定了是否用LSGAN对应的损失函数。LSGAN是GAN的一种变体,损失函数略有不同。只关注use_lsgan=false的情况
  if use_lsgan:
    # 使用均方损失
    loss = tf.reduce_mean(tf.squared_difference(D(fake_y), REAL_LABEL))
  else:
    # D(fake_y)为生成器生成图像是真实图像的概率,D(fake_y)越大,说明生成器越好
    # 之所以加负号,是因为tensorflow的优化器都默认损失越小越好
    loss = -tf.reduce_mean(ops.safe_log(D(fake_y))) / 2
  return loss

 

4、鉴别器的损失(model.py)

def discriminator_loss(self, D, y, fake_y, use_lsgan=True):
  # 只关注use_lsgan=false
  if use_lsgan:
    # use mean squared error
    error_real = tf.reduce_mean(tf.squared_difference(D(y), REAL_LABEL))
    error_fake = tf.reduce_mean(tf.square(D(fake_y)))
  else:
    # y是真实数据,D(y)是判别器判断真实数据的对应概率,该值越大,说明判别器的性能越好,同样取负号
    error_real = -tf.reduce_mean(ops.safe_log(D(y)))
    # 再使用交叉摘损失并取负值得到error_fake
    error_fake = -tf.reduce_mean(ops.safe_log(1-D(fake_y)))
  # 总损失
  loss = (error_real + error_fake) / 2
  return loss

 

5、为4个损失定义优化操作

最终定义出4个损失: G一loss、 F_loss、 D_Y一loss、 D_X一loss。其中, G_loss和 F_loss是生成器损失,这两个损失降低则意昧着生成器的性能提高 。D_Y_loss 和 D X loss 是判别器 , 这两个损失的降低意昧着判别器性能提高。在优化时, 对四个损失同时优化即可。

def optimize(self, G_loss, D_Y_loss, F_loss, D_X_loss):
  # 对四个损失定义优化操作
  G_optimizer = make_optimizer(G_loss, self.G.variables, name='Adam_G')
  D_Y_optimizer = make_optimizer(D_Y_loss, self.D_Y.variables, name='Adam_D_Y')
  F_optimizer =  make_optimizer(F_loss, self.F.variables, name='Adam_F')
  D_X_optimizer = make_optimizer(D_X_loss, self.D_X.variables, name='Adam_D_X')
  # tf.no_op()将优化操作保存,直接调用optimizers即可完成对四个损失的优化
  with tf.control_dependencies([G_optimizer, D_Y_optimizer, F_optimizer, D_X_optimizer]):
    return tf.no_op(name='optimizers')
# 优化器定义函数
def make_optimizer(loss, variables, name='Adam'):
  """ Adam optimizer with learning rate 0.0002 for the first 100k steps (~100 epochs)
      and a linearly decaying rate that goes to zero over the next 100k steps
  """
  global_step = tf.Variable(0, trainable=False)
  starter_learning_rate = self.learning_rate
  end_learning_rate = 0.0
  start_decay_step = 100000
  decay_steps = 100000
  beta1 = self.beta1
  learning_rate = (
      tf.where(
              tf.greater_equal(global_step, start_decay_step),
              tf.train.polynomial_decay(starter_learning_rate, global_step-start_decay_step,
                                        decay_steps, end_learning_rate,
                                        power=1.0),
              starter_learning_rate
      )

  )
  tf.summary.scalar('learning_rate/{}'.format(name), learning_rate)

  learning_step = (
      tf.train.AdamOptimizer(learning_rate, beta1=beta1, name=name)
              .minimize(loss, global_step=global_step, var_list=variables)
  )
  return learning_step

 

四、总结

本文首先介绍了CycleGAN的原理,接着在tensorflow中用CycleGAN训练了两个模型(苹果橘子转换,男性女性转换),最后介绍了模型和损失的定义细节。CycleGAN 不 需要成对数据就可以训练,具有较强的通用性,由此产生了大量有创意的应用,例如男女互换。

 

  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值