CycleGAN with TensorFlow 2 (horse-to-zebra image generation)

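This tutorial demonstrates unpaired image-to-image translation with conditional adversarial networks, known as CycleGAN, when no paired training examples are available. CycleGAN is trained with a cycle consistency loss, so it can translate the characteristics of one domain to another even without a one-to-one mapping between the source and target domains. The main steps are setting up the input pipeline, importing and reusing the Pix2Pix models, defining the loss functions, checkpointing, training, and generating images from the test dataset.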

Copyright 2019 The TensorFlow Authors.
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

CycleGAN

This notebook demonstrates unpaired image-to-image translation using conditional GANs, as described in Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks, also known as CycleGAN. The paper proposes a method that captures the characteristics of one image domain and works out how they can be translated into another image domain, all without any paired training examples.

This notebook assumes you are familiar with Pix2Pix, which you can learn about in the Pix2Pix tutorial. The code for CycleGAN is similar; the main differences are an additional loss function and the use of unpaired training data.

CycleGAN uses a cycle consistency loss to enable training without the need for paired data. In other words, it can translate from one domain to another without a one-to-one mapping between the source and target domain.
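The idea of cycle consistency can be made concrete with a small sketch: translate an image from domain X to domain Y, translate the result back, and penalize the mean absolute difference from the original. The snippet below is an illustration only. It uses NumPy instead of TensorFlow so it runs anywhere, the two "generators" are hypothetical toy functions rather than trained networks, and the weight `LAMBDA = 10` is an assumption based on the value reported in the CycleGAN paper.

```python
import numpy as np

LAMBDA = 10.0  # assumed weight of the cycle term (value reported in the paper)


def cycle_consistency_loss(real_image, cycled_image, weight=LAMBDA):
    """Weighted mean absolute error (L1) between an image and its round trip."""
    real_image = np.asarray(real_image, dtype=np.float32)
    cycled_image = np.asarray(cycled_image, dtype=np.float32)
    return weight * float(np.mean(np.abs(real_image - cycled_image)))


# Hypothetical stand-ins for the two generators, G: X -> Y and F: Y -> X.
# Real generators are neural networks; these toy functions only make the
# round trip X -> Y -> X concrete.
def toy_generator_g(x):
    return x + 0.25


def toy_generator_f(y):
    return y - 0.2  # deliberately an imperfect inverse of toy_generator_g


rng = np.random.default_rng(0)
# One synthetic "image" batch with values in [-1, 1], shaped like the
# batches used in this tutorial: (batch, height, width, channels).
real_x = rng.uniform(-1.0, 1.0, size=(1, 256, 256, 3)).astype(np.float32)

fake_y = toy_generator_g(real_x)    # translate X -> Y
cycled_x = toy_generator_f(fake_y)  # translate back Y -> X

loss = cycle_consistency_loss(real_x, cycled_x)
print(f"cycle loss: {loss:.3f}")
```

A perfect round trip gives a loss of zero, so minimizing this term pushes the two generators toward being inverses of each other. That constraint is what stands in for paired examples.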

This makes many interesting tasks possible, such as photo enhancement, image colorization, and style transfer. All you need is a source dataset and a target dataset (each is simply a directory of images).

[Figure: images/horse2zebra_1.png]
[Figure: images/horse2zebra_2.png]

Set up the input pipeline

Install the tensorflow_examples package that enables importing of the generator and the discriminator.

#!pip install -q git+https://github.com/tensorflow/examples.git
#!pip install tensorflow_datasets
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

tfds.disable_progress_bar()
AUTOTUNE = tf.data.experimental.AUTOTUNE

Input Pipeline

This tutorial trains a model to translate images of horses into images of zebras. You can find this dataset and similar ones here.

As mentioned in the paper, apply random jittering and mirroring to the training dataset. These are image augmentation techniques that help avoid overfitting.

This is similar to what was done in pix2pix.

  • In random jittering, the image is resized to 286 x 286 and then randomly cropped to 256 x 256.
  • In random mirroring, the image is randomly flipped horizontally, i.e., left to right.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
Downloading and preparing dataset cycle_gan/horse2zebra/2.0.0 (download: 111.45 MiB, generated: Unknown size, total: 111.45 MiB) to C:\Users\Administrator\tensorflow_datasets\cycle_gan\horse2zebra\2.0.0...


d:\python_virtualenv\tf2.1_gpu\lib\site-packages\urllib3\connectionpool.py:986: InsecureRequestWarning: Unverified HTTPS request is being made to host '127.0.0.1'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning,


Shuffling and writing examples to C:\Users\Administrator\tensorflow_datasets\cycle_gan\horse2zebra\2.0.0.incomplete460VY4\cycle_gan-trainA.tfrecord
Shuffling and writing examples to C:\Users\Administrator\tensorflow_datasets\cycle_gan\horse2zebra\2.0.0.incomplete460VY4\cycle_gan-trainB.tfrecord
Shuffling and writing examples to C:\Users\Administrator\tensorflow_datasets\cycle_gan\horse2zebra\2.0.0.incomplete460VY4\cycle_gan-testA.tfrecord
Shuffling and writing examples to C:\Users\Administrator\tensorflow_datasets\cycle_gan\horse2zebra\2.0.0.incomplete460VY4\cycle_gan-testB.tfrecord
Dataset cycle_gan downloaded and prepared to C:\Users\Administrator\tensorflow_datasets\cycle_gan\horse2zebra\2.0.0. Subsequent calls will reuse this data.
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
# Cache the raw training images before mapping, so the random jitter is
# re-sampled every epoch instead of being frozen into the cache.
train_horses = train_horses.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

# Test preprocessing is deterministic, so caching after the map is safe.
test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
sample_horse = next(iter(train_horses))
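Because `normalize` scales pixels to [-1, 1], an image drawn from the pipeline has to be mapped back to [0, 1] before matplotlib can display it as a float image. The sketch below checks that round trip on a tiny synthetic array. It is an illustration under stated assumptions: it uses NumPy rather than the TensorFlow pipeline above, and `normalize_np` and `denormalize_np` are hypothetical helper names, not part of the tutorial.

```python
import numpy as np


def normalize_np(image):
    """Map uint8 pixel values in [0, 255] to floats in [-1, 1]."""
    return image.astype(np.float32) / 127.5 - 1.0


def denormalize_np(image):
    """Map floats in [-1, 1] back to [0, 1] for display as a float image."""
    return image * 0.5 + 0.5


# Synthetic stand-in for one decoded image (not taken from the dataset):
# a single pixel with the darkest, a middle, and the brightest channel value.
pixels = np.array([[[0, 127, 255]]], dtype=np.uint8)

scaled = normalize_np(pixels)
restored = denormalize_np(scaled)

print("normalized range:", scaled.min(), scaled.max())
print("display range:", restored.min(), restored.max())
```

With a batch from the pipeline, the same rescaling would presumably be applied to a single element, for example `sample_horse[0] * 0.5 + 0.5`, before passing it to `plt.imshow`.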