(三)使用Keras构建移动风格迁移CycleGAN

目录

介绍

处理数据集

构建生成器和鉴别器

下一步


介绍

在本系列文章中,我们将展示一个基于循环一致对抗网络 (CycleGAN)的移动图像到图像转换系统。我们将构建一个CycleGAN,它可以执行不成对的图像到图像的转换,并向您展示一些有趣但具有学术深度的例子。我们还将讨论如何将这种使用TensorFlowKeras构建的训练有素的网络转换为TensorFlow Lite并用作移动设备上的应用程序。

我们假设您熟悉深度学习的概念,以及Jupyter NotebooksTensorFlow。欢迎您下载项目代码。

上一篇文章中​​​​​​​,我们讨论了CycleGAN架构。现在我们完成了理论。在本文中,我们将从头开始实现CycleGAN

我们的CycleGAN将使用马到斑马数据集执行未配对的图像到图像的转换,您可以下载该数据集。我们将使用TensorFlowKeras实现我们的网络,以及来自Pix.Pix库的生成器和鉴别器。我们将通过tensorflow_examples包导入生成器和鉴别器以简化实现。但是,在随后的一篇文章中,我们还将向您展示如何从头开始构建新的生成器和鉴别器。

值得一提的是,CycleGAN是一个非常耗电和内存消耗的网络。您的系统必须具有至少8 GB的足够RAM和与GTX 1660 Ti相同或更好的GPU,才能训练和运行CycleGAN,而不会出现内存不足错误或超时。

我们将使用GoogleColab训练我们的网络,这是一种托管的Jupyter Notebook服务,可以免费访问计算资源,包括GPU。最重要的是,它是免费的,不像其他一些云计算服务。

处理数据集

让我们加载数据集并应用一些预处理技术,例如裁剪、抖动和镜像,这将帮助我们避免网络过度拟合:

  • 图像抖动将图像大小调整为286 x 286像素,然后从随机选择的原点将其裁剪为256 x 256像素
  • 图像镜像从左到右水平翻转图像。

上述技术在原始CycleGAN论文中有所描述。

我们会将我们的数据上传到Google云端硬盘,以便Google Colab可以访问这些数据。数据上传后,我们就可以开始读取数据了。或者,您可以简单地在代码中使用tfds.load直接从TensorFlow数据集包中加载数据集,如下所示。

首先,让我们导入一些必需的依赖项:

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

现在我们将下载数据集并将上面讨论的增强技术应用于它:

dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']

加载数据后,让我们添加一些预处理功能:

def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image

# normalizing images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image

def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # randomly mirroring
  image = tf.image.random_flip_left_right(image)

  return image

def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image

def preprocess_image_test(image, label):
  image = normalize(image)
  return image

现在,我们将读取图像:

train_horses = train_horses.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

train_zebras = train_zebras.map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(1)

sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
############################Mirroring and jittering
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random mirroring')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)

这是抖动图像的示例。

构建生成器和鉴别器

现在,我们从pix2pix模型中导入生成器和鉴别器。我们将使用基于U-Net的生成器,而不是CycleGAN论文中使用的残差块。我们将使用U-Net,因为它比Residual块具有更简单的结构并且需要更少的计算。但是,我们将在另一篇文章中发现基于残差块的生成器。

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)

有了生成器和鉴别器,我们就可以开始设置损失了。由于CycleGAN是不成对的图像到图像的转换,因此不需要成对的数据来训练网络。因此,没有人能保证在训练过程中输入和目标图像是有意义的一对。这就是为什么计算循环一致性损失以使网络映射正确很重要的原因:

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5

def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

现在,我们计算循环一致性损失以确保转换结果接近原始图像:

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))
  
  return LAMBDA * loss1

def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

最后,我们为生成器和鉴别器设置优化器:

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

下一步

接下来的文章中​​​​​​​,我们将向你展示如何训练我们的CycleGAN以转换马-斑马和斑马-马。

https://www.codeproject.com/Articles/5304922/Building-a-Mobile-Style-Transfer-CycleGAN-with-Ker

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值