CycleGAN算法原理(附源代码,可直接运行)

前言

CycleGAN是在今年三月底放在arxiv(论文地址CycleGAN)的一篇文章,文章名为Learning to Discover Cross-Domain Relations with Generative Adversarial Networks,同一时期还有两篇非常类似的DualGAN(论文地址:DualGAN)和DiscoGAN(论文地址:DiscoGAN),简单来说,它们的功能就是:自动将某一类图片转换成另外一类图片。不同于GAN和CGAN(上节已经介绍过),CycleGAN不需要配对的训练图像。当然了配对图像也完全可以,不过大多时候配对图像比较难获取。

这里写图片描述
这里写图片描述
配对图像
这里写图片描述
未配对的图像

CycleGAN能做什么?

CycleGAN可以完成GAN和CGAN的工作,如上述配对图像所示,可以从一个特定的场景模式图生成另外一个场景模式图,这两张场景模式中的物体完全相同。除此之外,CycleGAN还可以完成从一个模式到另外一个模式的转换,转换的过程中,物体发生了改变,比如下面的图像中从猫到狗,从男人到女人。

这里写图片描述
这里写图片描述

CycleGAN算法原理

如下图所示CycleGAN其实是由两个判别器( Dx D x Dy D y )和两个生成器(G和F)组成,但是为什么要连两个生成器和两个判别器呢?论文中说,是为了避免所有的X都被映射到同一个Y,比如所有男人的图像都映射到范冰冰的图像上,这显然不合理,所以为了避免这种情况,论文采用了两个生成器的方式,既能满足X->Y的映射,又能满足Y->X的映射,这一点其实就是变分自编码器VAE的思想,是为了适应不同输入图像产生不同输出图像。那么下面的四个公式也很清楚了,(1)是判别器Y对X->Y的映射G的损失,判别器X对Y->X映射的损失也非常类似(2)是两个生成器的循环损失,这里其实是 L1 L 1 损失(3)是总损失(4)是对总损失进行优化,先优化D然后优化G和F,这一点和GAN类似

这里写图片描述
这里写图片描述
(1)
这里写图片描述
(2)
这里写图片描述
(3)
这里写图片描述
(4)

源代码

训练源代码

import tensorflow as tf
from model import CycleGAN
from reader import Reader
from datetime import datetime
import os
import logging
from utils import ImagePool

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1')
tf.flags.DEFINE_integer('image_size', 128, 'image size, default: 256')
tf.flags.DEFINE_bool('use_lsgan', True,
                     'use lsgan (mean squared error) or cross entropy loss, default: True')
tf.flags.DEFINE_string('norm', 'instance',
                       '[instance, batch] use instance norm or batch norm, default: instance')
tf.flags.DEFINE_integer('lambda1', 10.0,
                        'weight for forward cycle loss (X->Y->X), default: 10.0')
tf.flags.DEFINE_integer('lambda2', 10.0,
                        'weight for backward cycle loss (Y->X->Y), default: 10.0')
tf.flags.DEFINE_float('learning_rate', 2e-4,
                      'initial learning rate for Adam, default: 0.0002')
tf.flags.DEFINE_float('beta1', 0.5,
                      'momentum term of Adam, default: 0.5')
tf.flags.DEFINE_float('pool_size', 50,
                      'size of image buffer that stores previously generated images, default: 50')
tf.flags.DEFINE_integer('ngf', 64,
                        'number of gen filters in first conv layer, default: 64')

tf.flags.DEFINE_string('X', 'tfrecords/apple.tfrecords',
                       'X tfrecords file for training, default: tfrecords/apple.tfrecords')
tf.flags.DEFINE_string('Y', 'tfrecords/orange.tfrecords',
                       'Y tfrecords file for training, default: tfrecords/orange.tfrecords')
tf.flags.DEFINE_string('load_model', None,
                        'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None')


def train():
  if FLAGS.load_model is not None:
    checkpoints_dir = "checkpoints/" + FLAGS.load_model
  else:
    current_time = datetime.now().strftime("%Y%m%d-%H%M")
    checkpoints_dir = "checkpoints/{}".format(current_time)
    try:
      os.makedirs(checkpoints_dir)
    except os.error:
      pass

  graph = tf.Graph()
  with graph.as_default():
    cycle_gan = CycleGAN(
        X_train_file=FLAGS.X,
        Y_train_file=FLAGS.Y,
        batch_size=FLAGS.batch_size,
        image_size=FLAGS.image_size,
        use_lsgan=FLAGS.use_lsgan,
        norm=FLAGS.norm,
        lambda1=FLAGS.lambda1,
        lambda2=FLAGS.lambda1,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        ngf=FLAGS.ngf
    )
    G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()
    optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)

    summary_op = tf.summary.merge_all()
    train_writer = tf.summary.FileWriter(checkpoints_dir, graph)
    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    if FLAGS.load_model is not None:
      checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)
      meta_graph_path = checkpoint.model_checkpoint_path + ".meta"
      restore = tf.train.import_meta_graph(meta_graph_path)
      restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))
      step = int(meta_graph_path.split("-")[2].split(".")[0])
    else:
      sess.run(tf.global_variables_initializer())
      step = 0

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
      fake_Y_pool = ImagePool(FLAGS.pool_size)
      fake_X_pool = ImagePool(FLAGS.pool_size)

      while not coord.should_stop():
        # get previously generated images
        fake_y_val, fake_x_val = sess.run([fake_y, fake_x])

        # train
        _, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (
              sess.run(
                  [optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],
                  feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),
                             cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}
              )
        )
        if step % 100 == 0:
          train_writer.add_summary(summary, step)
          train_writer.flush()

        if step % 100 == 0:
          logging.info('-----------Step %d:-------------' % step)
          logging.info('  G_loss   : {}'.format(G_loss_val))
          logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))
          logging.info('  F_loss   : {}'.format(F_loss_val))
          logging.info('  D_X_loss : {}'.format(D_X_loss_val))

        if step % 1000 == 0:
          save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
          logging.info("Model saved in file: %s" % save_path)

        step += 1

    except KeyboardInterrupt:
      logging.info('Interrupted')
      coord.request_stop()
    except Exception as e:
      coord.request_stop(e)
    finally:
      save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)
      logging.info("Model saved in file: %s" % save_path)
      # When done, ask the threads to stop.
      coord.request_stop()
      coord.join(threads)

def main(unused_argv):
  train()

if __name__ == '__main__':
  logging.basicConfig(level=logging.INFO)
  tf.app.run()

测试源代码

"""Translate an image to another image
An example of command-line usage is:
python export_graph.py --model pretrained/apple2orange.pb \
                       --input input_sample.jpg \
                       --output output_sample.jpg \
                       --image_size 256
"""

import tensorflow as tf
import os
from model import CycleGAN
import utils

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('model', 'model/apple2orange.pb', 'model path (.pb)')
tf.flags.DEFINE_string('input', 'samples/real_apple2orange_4.jpg', 'input image path (.jpg)')
tf.flags.DEFINE_string('output', 'output/output_sample3.jpg', 'output image path (.jpg)')
tf.flags.DEFINE_integer('image_size', '256', 'image size, default: 256')

def inference():
  graph = tf.Graph()

  with graph.as_default():
    with tf.gfile.FastGFile(FLAGS.input, 'rb') as f:
      image_data = f.read()
      input_image = tf.image.decode_jpeg(image_data, channels=3)
      input_image = tf.image.resize_images(input_image, size=(FLAGS.image_size, FLAGS.image_size))
      input_image = utils.convert2float(input_image)
      input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])

    with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
      graph_def = tf.GraphDef()
      graph_def.ParseFromString(model_file.read())
    [output_image] = tf.import_graph_def(graph_def,
                          input_map={'input_image': input_image},
                          return_elements=['output_image:0'],
                          name='output')

  with tf.Session(graph=graph) as sess:
    generated = output_image.eval()
    with open(FLAGS.output, 'wb') as f:
      f.write(generated)

def main(unused_argv):
  inference()

if __name__ == '__main__':
  tf.app.run()

实验结果

在这里是以相同物体不同模式下的数据集做训练(由于没有找到不同物体不同模式下的数据,当然你也可以自己做),从苹果到橘子的训练,测试结果如下:

这里写图片描述

从上图可以看出,苹果的颜色已经改成橘色,效果得到了体现。
源代码链接:CycleGAN source code

  • 18
    点赞
  • 246
    收藏
    觉得还不错? 一键收藏
  • 20
    评论
1.版本:matlab2014/2019a,内含运行结果,不会运行可私信 2.领域:智能优化算法、神经网络预测、信号处理、元胞自动机、像处理、路径规划、无人机等多种领域的Matlab仿真,更多内容可点击博主头像 3.内容:标题所示,对于介绍可点击主页搜索博客 4.适合人群:本科,硕士等教研学习使用 5.博客介绍:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可si信 ### 团队长期从事下列领域算法的研究和改进: ### 1 智能优化算法及应用 **1.1 改进智能优化算法方面(单目标和多目标)** **1.2 生产调度方面** 1.2.1 装配线调度研究 1.2.2 车间调度研究 1.2.3 生产线平衡研究 1.2.4 水库梯度调度研究 **1.3 路径规划方面** 1.3.1 旅行商问题研究(TSP、TSPTW) 1.3.2 各类车辆路径规划问题研究(vrp、VRPTW、CVRP) 1.3.3 机器人路径规划问题研究 1.3.4 无人机三维路径规划问题研究 1.3.5 多式联运问题研究 1.3.6 无人机结合车辆路径配送 **1.4 三维装箱求解** **1.5 物流选址研究** 1.5.1 背包问题 1.5.2 物流选址 1.5.4 货位优化 ##### 1.6 电力系统优化研究 1.6.1 微电网优化 1.6.2 配电网系统优化 1.6.3 配电网重构 1.6.4 有序充电 1.6.5 储能双层优化调度 1.6.6 储能优化配置 ### 2 神经网络回归预测、时序预测、分类清单 **2.1 bp预测和分类** **2.2 lssvm预测和分类** **2.3 svm预测和分类** **2.4 cnn预测和分类** ##### 2.5 ELM预测和分类 ##### 2.6 KELM预测和分类 **2.7 ELMAN预测和分类** ##### 2.8 LSTM预测和分类 **2.9 RBF预测和分类** ##### 2.10 DBN预测和分类 ##### 2.11 FNN预测 ##### 2.12 DELM预测和分类 ##### 2.13 BIlstm预测和分类 ##### 2.14 宽度学习预测和分类 ##### 2.15 模糊小波神经网络预测和分类 ##### 2.16 GRU预测和分类 ### 3 像处理算法 **3.1 像识别** 3.1.1 车牌、交通标志识别(新能源、国内外、复杂环境下车牌) 3.1.2 发票、身份证、银行卡识别 3.1.3 人脸类别和表情识别 3.1.4 打靶识别 3.1.5 字符识别(字母、数字、手写体、汉字、验证码) 3.1.6 病灶识别 3.1.7 花朵、药材、水果蔬菜识别 3.1.8 指纹、手势、虹膜识别 3.1.9 路面状态和裂缝识别 3.1.10 行为识别 3.1.11 万用表和表盘识别 3.1.12 人民币识别 3.1.13 答题卡识别 **3.2 像分割** **3.3 像检测** 3.3.1 显著性检测 3.3.2 缺陷检测 3.3.3 疲劳检测 3.3.4 病害检测 3.3.5 火灾检测 3.3.6 行人检测 3.3.7 水果分级 **3.4 像隐藏** **3.5 像去噪** **3.6 像融合** **3.7 像配准** **3.8 像增强** **3.9 像压缩** ##### 3.10 像重建 ### 4 信号处理算法 **4.1 信号识别** **4.2 信号检测** **4.3 信号嵌入和提取** **4.4 信号去噪** ##### 4.5 故障诊断 ##### 4.6 脑电信号 ##### 4.7 心电信号 ##### 4.8 肌电信号 ### 5 元胞自动机仿真 **5.1 模拟交通流** **5.2 模拟人群疏散** **5.3 模拟病毒扩散** **5.4 模拟晶体生长** ### 6 无线传感器网络 ##### 6.1 无线传感器定位 ##### 6.2 无线传感器覆盖优化 ##### 6.3 室内定位 ##### 6.4 无线传感器通信及优化 ##### 6.5 无人机通信中继优化
CycleGAN是一个无监督的像转换模型,可以将一种领域的像转换成另一种领域的像,而无需手动标注数据集。其核心思想是通过两个生成器和两个判别器,来实现两个领域之间的像转换。下面我们来看一下CycleGAN源代码解读。 CycleGAN的主要代码在`models`文件夹下,其中`cycle_gan_model.py`定义了CycleGAN的模型结构,`networks.py`定义了生成器和判别器的网络结构。其中生成器采用U-Net结构,判别器采用PatchGAN结构。`options`文件夹下的`base_options.py`定义了模型的一些基本参数,包括训练数据路径、模型保存路径、学习率等。`train_options.py`继承了`base_options.py`,并添加了一些训练相关的参数,比如迭代次数、是否使用L1损失等。`test_options.py`同样继承了`base_options.py`,并添加了一些测试相关的参数,比如测试数据路径、输出结果路径等。 在`train.py`文件中,我们可以看到CycleGAN的训练流程。首先定义了模型、数据加载器、优化器等,然后开始训练。在训练过程中,先通过生成器将A领域的片转换成B领域的片,然后将转换后的片与B领域的真实片送入判别器,计算判别器的损失。同时,也计算生成器的损失,包括对抗损失、循环一致性损失和L1损失。最后通过反向传播更新生成器和判别器的参数。 在`test.py`文件中,我们可以看到CycleGAN的测试流程。首先定义了模型和数据加载器,然后通过生成器将A领域的片转换成B领域的片,并将转换后的片保存到输出结果路径中。 总之,CycleGAN源代码实现了一个完整的无监督像转换模型,包括模型结构、数据加载、训练和测试流程。如果想要深入了解CycleGAN,可以从源代码入手,逐步理解其实现原理

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值