基于tensorflow2的pix2pix图像生成(建筑物还原,U-net,gan)

本教程演示了使用条件对抗网络进行图像到图像转换的Pix2Pix技术,适用于黑白照片上色、谷歌地图转谷歌地球等。这里通过CMP Facade数据库,将建筑立面转化为真实建筑。文章介绍了数据加载、生成器、判别器的构建,以及训练和测试过程。训练200个周期后,模型能生成逼真的建筑图像。
摘要由CSDN通过智能技术生成
Copyright 2019 The TensorFlow Authors.

Licensed under the Apache License, Version 2.0 (the “License”);

#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Pix2Pix

View on TensorFlow.org Run in Google Colab View source on GitHub Download notebook

This notebook demonstrates image to image translation using conditional GAN’s, as described in Image-to-Image Translation with Conditional Adversarial Networks. Using this technique we can colorize black and white photos, convert google maps to google earth, etc. Here, we convert building facades to real buildings.

In example, we will use the CMP Facade Database, helpfully provided by the Center for Machine Perception at the Czech Technical University in Prague. To keep our example short, we will use a preprocessed copy of this dataset, created by the authors of the paper above.

Each epoch takes around 15 seconds on a single V100 GPU.

Below is the output generated after training the model for 200 epochs.

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MHUGEvrq-1591008034678)(https://www.tensorflow.org/images/gan/pix2pix_1.png)]
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-CEQTA6Y7-1591008034689)(https://www.tensorflow.org/images/gan/pix2pix_2.png)]

Import TensorFlow and other libraries

import tensorflow as tf

import os
import time

from matplotlib import pyplot as plt
from IPython import display
!pip install -q -U tensorboard

Load the dataset

You can download this dataset and similar datasets from here. As mentioned in the paper we apply random jittering and mirroring to the training dataset.

  • In random jittering, the image is resized to 286 x 286 and then randomly cropped to 256 x 256
  • In random mirroring, the image is randomly flipped horizontally i.e left to right.
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/facades.tar.gz'

path_to_zip = tf.keras.utils.get_file('facades.tar.gz',
                                      origin=_URL,
                                      extract=True)

PATH = os.path.join(os.path.dirname(path_to_zip), 'facades/')
BUFFER_SIZE = 400
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def load(image_file):
  image = tf.io.read_file(image_file)
  image = tf.image.decode_jpeg(image)

  w = tf.shape(image)[1]

  w = w // 2
  real_image = image[:, :w, :]
  input_image = image[:, w:, :]

  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)

  return input_image, real_image
inp, re = load(PATH+'train/100.jpg')
# casting to int for matplotlib to show the image
plt.figure()
plt.imshow(inp/255.0)
plt.figure()
plt.imshow(re/255.0)
<matplotlib.image.AxesImage at 0x1b2c1d68>

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LqZwbBHP-1591008034702)(output_12_1.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VSE01ewQ-1591008034704)(output_12_2.png)]

def resize(input_image, real_image, height, width):
  input_image = tf.image.resize(input_image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  real_image = tf.image.resize(real_image, [height, width],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  return input_image, real_image
def random_crop(input_image, real_image):
  stacked_image = tf.stack([input_image, real_image], axis=0)
  cropped_image = tf.image.random_crop(
      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image[0], cropped_image[1]
# normalizing the images to [-1, 1]

def normalize(input_image, real_image):
  input_image = (input_image / 127.5) - 1
  real_image = (real_image / 127.5) - 1

  return input_image, real_image
@tf.function()
def random_jitter(input_image, real_image):
  # resizing to 286 x 286 x 3
  input_image, real_image = resize(input_image, real_image, 286, 286)

  # randomly cropping to 256 x 256 x 3
  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

As you can see in the images below
that they ar

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值