tf2.0 cycle-gan,官方代码复现整理。

82 篇文章 46 订阅 ¥59.90 ¥99.00
27 篇文章 0 订阅

官方数据集网址:https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/

官方教程:https://tensorflow.google.cn/tutorials

pix2pix.py:在源代码所在目录下新建 pix2pix 文件夹,并将 pix2pix.py 放入其中(代码通过 from pix2pix import pix2pix 导入)。下载地址:https://download.csdn.net/download/qq_38784454/14016118

官方没有提供独立的测试脚本,我编写的测试代码见:https://download.csdn.net/download/qq_38784454/18207785

CycleGAN 代码(已修改,可直接运行)。TensorFlow 默认会在启动时占用几乎全部 GPU 显存,因此我添加了按需增长(memory growth)的显存设置:

import tensorflow as tf
import tensorflow_datasets as tfds
from pix2pix import pix2pix

import os

# os.environ['CUDA_VISIBLE_DEVICES'] = '/device:GPU:0'

import time, argparse
import matplotlib.pyplot as plt
from IPython.display import clear_output


# GPU memory setup: let TensorFlow grow GPU memory on demand instead of
# reserving (nearly) all of it at start-up.
physical_gpus = tf.config.list_physical_devices("GPU")
# Configure every visible GPU. Indexing [0] raised IndexError on CPU-only
# machines and left any additional GPUs unconfigured.
# Memory growth must be set before the GPUs are initialized.
for gpu in physical_gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Alternative: load the official horse2zebra dataset through tfds.
# dataset, metadata = tfds.load(name='horse2zebra',data_dir='E:\\Users\\CycleGAN-tf2.0-tourtial\\dataset',
#                               with_info=True, as_supervised=True)

# train_horses, train_zebras = dataset['trainA'], dataset['trainB']
# test_horses, test_zebras = dataset['testA'], dataset['testB']
def arg_parser():
    """Parse the script's command-line options.

    Returns:
        argparse.Namespace with a single string attribute ``DATASET``
        (default ``'EL\\\\finger'``), which the caller uses as the dataset
        sub-directory name.
    """
    parser = argparse.ArgumentParser()
    # Another dataset used with this script: --DATASET glass2sheep
    parser.add_argument("--DATASET", type=str, default='EL\\finger')
    return parser.parse_args()

args = arg_parser()
# Build paths with os.path.join instead of hard-coded '\\' separators so the
# script also runs outside Windows. Backslashes inside --DATASET (its default
# is 'EL\\finger') are normalized to the platform separator too.
dataset_dir = os.path.normpath(args.DATASET.replace('\\', '/'))
dataset_path = os.path.join(os.getcwd(), 'dataset', dataset_dir, '')
train_horses = tf.data.Dataset.list_files(os.path.join(dataset_path, 'trainA', '*'))
train_zebras = tf.data.Dataset.list_files(os.path.join(dataset_path, 'trainB', '*'))
test_horses = tf.data.Dataset.list_files(os.path.join(dataset_path, 'testA_paper', '*'))
test_zebras = tf.data.Dataset.list_files(os.path.join(dataset_path, 'testB_paper', '*'))

# Number of elements held by the tf.data shuffle buffer.
BUFFER_SIZE = 1000
# Mini-batch size.
BATCH_SIZE = 1
# Spatial size (height x width) of the images fed to the networks.
IMG_HEIGHT = IMG_WIDTH = 256

def random_crop(image):
  """Cut a random IMG_HEIGHT x IMG_WIDTH x 3 window out of `image`."""
  target_shape = [IMG_HEIGHT, IMG_WIDTH, 3]
  return tf.image.random_crop(image, size=target_shape)
# Scale images into the interval [-1, 1].

def normalize(image):
  """Cast to float32 and map pixel values linearly via x / 127.5 - 1.

  Assumes 8-bit pixel values in [0, 255], which then land in [-1, 1].
  """
  as_float = tf.cast(image, tf.float32)
  return (as_float / 127.5) - 1

def random_jitter(image):
  """Augment one image: upscale, take a random crop, randomly mirror it."""
  # Upscale to 286 x 286 with nearest-neighbour sampling so the crop has room to move.
  upscaled = tf.image.resize(
      image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  # Random IMG_HEIGHT x IMG_WIDTH x 3 window (256 x 256 x 3 by default).
  cropped = random_crop(upscaled)
  # Random horizontal mirror.
  return tf.image.random_flip_left_right(cropped)

def load(image_file):
  """Read a JPEG file and return it as a 3-channel float32 tensor.

  Args:
    image_file: scalar string tensor holding the image path.

  Returns:
    float32 tensor of shape (H, W, 3). Pixel values are not rescaled; they
    keep the decoded 8-bit range.
  """
  raw = tf.io.read_file(image_file)
  # Force 3 channels. Without it, grayscale JPEGs decode to a single channel,
  # which breaks random_crop(..., [H, W, 3]) and the 3-channel generators.
  image = tf.image.decode_jpeg(raw, channels=3)
  return tf.cast(image, tf.float32)

def preprocess_image_train(image_file):
  """Load a training image, apply random jitter, then scale it into [-1, 1]."""
  return normalize(random_jitter(load(image_file)))

def resize(input_image, height, width):
  """Resize `input_image` to (height, width) with nearest-neighbour sampling."""
  target_size = [height, width]
  method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
  return tf.image.resize(input_image, target_size, method=method)

def preprocess_image_test(image_file):
  """Load a test image, resize it to IMG_HEIGHT x IMG_WIDTH, scale into [-1, 1].

  No random augmentation is applied, so the result is deterministic.
  """
  image = load(image_file)
  # resize() takes (height, width); the original passed the two sizes swapped,
  # which only went unnoticed because both are 256.
  image = resize(image, IMG_HEIGHT, IMG_WIDTH)
  return normalize(image)

# Training pipelines. cache() comes BEFORE the random augmentation in
# preprocess_image_train. A cache placed after map() replays the jitter/flip
# from the first epoch, so every later epoch would see identical augmented
# images. The official TF tutorial uses the same ordering.
train_horses = train_horses.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

# Test preprocessing is deterministic, so caching its output is safe.
test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))

# Show each sample next to a freshly jittered copy (values mapped back to
# [0, 1] for display). Each pair gets its own figure: calling
# plt.subplot(121)/(122) again on the same figure made the zebra plots
# overwrite the horse plots.
plt.figure()
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

plt.figure()
plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)


OUTPUT_CHANNELS = 3
NORM_TYPE = 'instancenorm'

# Two U-Net generators: generator_g maps domain X -> Y, generator_f maps Y -> X.
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type=NORM_TYPE)
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type=NORM_TYPE)

# One discriminator per domain (X and Y), both unconditional (target=False).
discriminator_x = pix2pix.discriminator(norm_type=NORM_TYPE, target=False)
discriminator_y = pix2pix.discriminator(norm_type=NORM_TYPE, target=False)

to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
# Generated images are amplified by this factor before display, presumably
# so the low-amplitude output of untrained generators is visible.
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

# Real images sit at even positions, generated ones at odd positions.
for i, (img, name) in enumerate(zip(imgs, title)):
  plt.subplot(2, 2, i+1)
  plt.title(name)
  if i % 2 != 0:
    plt.imshow(img[0] * 0.5 * contrast + 0.5)
  else:
    plt.imshow(img[0] * 0.5 + 0.5)
plt.show()

plt.figure(figsize=(8, 8))

# Plot the last channel of each discriminator's output map for one real sample.
for position, (heading, disc, batch) in enumerate(
    [('Is a real zebra?', discriminator_y, sample_zebra),
     ('Is a real horse?', discriminator_x, sample_horse)], start=1):
  plt.subplot(1, 2, position)
  plt.title(heading)
  plt.imshow(disc(batch)[0, ..., -1], cmap='RdBu_r')

plt.show()

# Weight applied to the L1 cycle-consistency loss (half of it to identity loss).
LAMBDA = 10

# Binary cross-entropy shared by all adversarial losses; expects raw logits.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
  """Half the sum of BCE(real logits vs 1) and BCE(generated logits vs 0)."""
  losses = [loss_obj(tf.ones_like(real), real),
            loss_obj(tf.zeros_like(generated), generated)]
  return (losses[0] + losses[1]) * 0.5

def generator_loss(generated):
  """BCE that rewards the generator when the discriminator scores fakes as real (1)."""
  all_real = tf.ones_like(generated)
  return loss_obj(all_real, generated)

def calc_cycle_loss(real_image, cycled_image):
  """LAMBDA-weighted mean absolute error between an image and its round trip."""
  mean_abs_diff = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return mean_abs_diff * LAMBDA

def identity_loss(real_image, same_image):
  """Mean absolute error between an image and `same_image`, weighted by 0.5 * LAMBDA."""
  weight = 0.5 * LAMBDA
  return weight * tf.reduce_mean(tf.abs(real_image - same_image))

# One independent Adam optimizer (lr 2e-4, beta_1 0.5) per network.
(generator_g_optimizer,
 generator_f_optimizer,
 discriminator_x_optimizer,
 discriminator_y_optimizer) = (tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
                               for _ in range(4))

checkpoint_path = "./checkpoints/train"

# Track all four networks and their optimizers so training can be resumed.
ckpt = tf.train.Checkpoint(
    generator_g=generator_g,
    generator_f=generator_f,
    discriminator_x=discriminator_x,
    discriminator_y=discriminator_y,
    generator_g_optimizer=generator_g_optimizer,
    generator_f_optimizer=generator_f_optimizer,
    discriminator_x_optimizer=discriminator_x_optimizer,
    discriminator_y_optimizer=discriminator_y_optimizer,
)

# Keep at most the five most recent checkpoints on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the most recent checkpoint, if there is one.
latest = ckpt_manager.latest_checkpoint
if latest:
  ckpt.restore(latest)
  print ('Latest checkpoint restored!!')

EPOCHS = 60

def generate_images(model, test_input, save_path, epoch):
  """Run `model` on `test_input` and save an input/prediction figure as JPEG.

  Args:
    model: callable (e.g. a generator) applied to the whole `test_input` batch.
    test_input: image batch; only its first element is plotted. Values are
      assumed to lie in [-1, 1], since they are mapped with *0.5 + 0.5.
    save_path: output prefix. The filename is appended by plain string
      concatenation, so pass a directory path ending with a separator.
    epoch: number embedded in the filename `<save_path><epoch>_test.jpg`.
  """
  prediction = model(test_input)
  filename = save_path + str(epoch) + '_test.jpg'

  # savefig does not create missing directories, so create them on first use.
  out_dir = os.path.dirname(filename)
  if out_dir:
    os.makedirs(out_dir, exist_ok=True)

  fig = plt.figure(figsize=(12, 12))
  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for position, (img, name) in enumerate(zip(display_list, title), start=1):
    plt.subplot(1, 2, position)
    plt.title(name)
    # Map pixel values back to [0, 1] for plotting.
    plt.imshow(img * 0.5 + 0.5)
    plt.axis('off')

  plt.savefig(filename)
  # This runs once per epoch and once per test image. Unclosed pyplot
  # figures would accumulate in memory, so close it.
  plt.close(fig)


@tf.function
def train_step(real_x, real_y, use_identity_loss=False):
  """Run one CycleGAN optimisation step on one batch from each domain.

  Generator G (generator_g) maps X -> Y and generator F (generator_f) maps
  Y -> X. All four networks are updated once, each by its own optimizer.
  Compiled with tf.function, as in the official TensorFlow CycleGAN
  tutorial, so the step runs as a graph instead of op-by-op eagerly.

  Args:
    real_x: image batch from domain X.
    real_y: image batch from domain Y.
    use_identity_loss: Python bool. When True, the identity loss
      identity_loss(x, F(x)) / identity_loss(y, G(y)) is added to the
      generator objectives. The default False keeps the original objective
      (adversarial + cycle loss).
  """
  # persistent=True because the same tape is used for four gradient calls.
  with tf.GradientTape(persistent=True) as tape:
    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # Adversarial losses for the generators.
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    # Cycle-consistency loss, shared by both generators.
    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss (+ identity loss).
    total_gen_g_loss = gen_g_loss + total_cycle_loss
    total_gen_f_loss = gen_f_loss + total_cycle_loss

    if use_identity_loss:
      # The "same" images are only needed for this term. Previously they were
      # computed on every step and then discarded.
      same_x = generator_f(real_x, training=True)
      same_y = generator_g(real_y, training=True)
      total_gen_g_loss += identity_loss(real_y, same_y)
      total_gen_f_loss += identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Gradients of each loss w.r.t. its own network's weights.
  generator_g_gradients = tape.gradient(total_gen_g_loss,
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss,
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss,
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss,
                                            discriminator_y.trainable_variables)
  # A persistent tape holds its resources until deleted.
  del tape

  # Apply the gradients with the matching optimizers.
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients,
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients,
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))

# Output prefix (directory with trailing separator) for the per-epoch sample
# figures. It does not change between epochs, so it is defined once here.
save_path = 'E:\\Users\\CycleGAN-tf2.0-original\\samples\\'

for epoch in range(EPOCHS):
  start = time.time()

  # Pair X and Y batches; iteration stops at the shorter of the two datasets.
  for n, (image_x, image_y) in enumerate(
      tf.data.Dataset.zip((train_horses, train_zebras))):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('step=%d' % n)

  clear_output(wait=True)
  # Always translate the same image (sample_horse) so progress across epochs
  # is directly comparable.
  generate_images(generator_g, sample_horse, save_path, epoch)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

# Run the trained generators on the test sets: up to 25 images per direction.
# Output files are numbered consecutively across both directions.
save_path_testAB = 'E:\\Users\\CycleGAN-tf2.0-original\\test_save\\AB\\'
save_path_testBA = 'E:\\Users\\CycleGAN-tf2.0-original\\test_save\\BA\\'

test_num = 0
for model, test_set, out_dir in ((generator_g, test_horses, save_path_testAB),
                                 (generator_f, test_zebras, save_path_testBA)):
  for inp in test_set.take(25):
    generate_images(model, inp, out_dir, test_num)
    test_num += 1

此代码的生成器采用 U-Net 结构,并使用 instance normalization(实例归一化),生成效果比原版代码更加稳定。

 

  • 1
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

苏打水的杯子

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值