GAN 实例

本文介绍了基于 TensorFlow 1.x 的 TFGAN（tf.contrib.gan）进行图像生成的实践过程，包括数据读取、网络构建、训练及评估等关键步骤：data_provider.py 负责加载和预处理数据，network.py 定义生成器和判别器的网络结构，train.py 实现模型训练，eval.py 用于评估模型效果，util.py 则是通用工具模块，辅助整个流程的执行。
摘要由CSDN通过智能技术生成

数据读取部分代码

data_provider.py

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the CIFAR data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import tensorflow as tf
import os
import numpy as np
#from slim.datasets import dataset_factory as datasets
from datasets import dataset_factory as datasets
slim = tf.contrib.slim


def provide_data(batch_size, dataset_dir, dataset_name='cifar10',
                 split_name='train', one_hot=True, num_threads=32):
  """Provides batches of images and labels from a slim dataset (CIFAR by default).

  Args:
    batch_size: The number of images in each batch.
    dataset_dir: The directory where the dataset files can be found. If `None`,
      use default.
    dataset_name: Name of the dataset passed to `datasets.get_dataset`
      (e.g. 'cifar10' or 'mnist').
    split_name: Should be either 'train' or 'test'. Only the 'train' split is
      shuffled.
    one_hot: Output one hot vector instead of int32 label.
    num_threads: Number of threads `tf.train.batch` uses to enqueue examples.
      Defaults to 32, the value this function always used.

  Returns:
    images: A float `Tensor` of size [batch_size, height, width, channels]
      ([batch_size, 32, 32, 3] for CIFAR-10). Pixels are mapped with
      `(x - 128) / 128`, so uint8 input in [0, 255] lands in [-1, 127/128].
    labels: Either (1) one_hot_labels if `one_hot` is `True`
            A `Tensor` of size [batch_size, num_classes], where each row has a
            single element set to one and the rest set to zeros.
            Or (2) labels if `one_hot` is `False`
            A `Tensor` of size [batch_size], holding the labels as integers.
    num_samples: The number of total samples in the dataset.
    num_classes: The number of classes in the dataset.

  Raises:
    ValueError: presumably raised by `datasets.get_dataset` when the dataset
      name or split is unknown; this function does not validate them itself.
  """

  print("provide_data.............")

  dataset = datasets.get_dataset(
      dataset_name, split_name, dataset_dir=dataset_dir)
  provider = slim.dataset_data_provider.DatasetDataProvider(
      dataset,
      common_queue_capacity=5 * batch_size,
      common_queue_min=batch_size,
      # Only the training split is shuffled; other splits are read in order.
      shuffle=(split_name == 'train'))
  [image, label] = provider.get(['image', 'label'])

  # Preprocess the images: rescale pixel values to roughly [-1, 1).
  image = (tf.to_float(image) - 128.0) / 128.0

  print("---------------image:",image)

  # Creates a QueueRunner for the pre-fetching operation.
  images, labels = tf.train.batch(
      [image, label],
      batch_size=batch_size,
      num_threads=num_threads,
      capacity=5 * batch_size)

  # Flatten to shape [batch_size] before (optionally) one-hot encoding.
  labels = tf.reshape(labels, [-1])

  if one_hot:
    labels = tf.one_hot(labels, dataset.num_classes)
  print("num_samples:",dataset.num_samples)
  print("dataset.num_classes",dataset.num_classes)
  return images, labels, dataset.num_samples, dataset.num_classes


def getimage(path):
    """Returns the paths of the JPEG files directly inside directory `path`.

    A file matches when its name ends in ``.jpg`` or ``.jpeg``, compared
    case-insensitively. Each matching file name is printed as it is found.

    Args:
      path: Directory to scan (not recursive). A trailing path separator is
        optional; paths are built with `os.path.join`.

    Returns:
      A list of file paths, sorted by file name. `os.listdir` returns entries
      in arbitrary order, and `provide_data_self` assigns labels by position
      in this list, so a stable order matters.
    """
    jpeg_paths = []
    for filename in sorted(os.listdir(path)):
        if filename.lower().endswith((".jpg", ".jpeg")):
            print(filename)
            jpeg_paths.append(os.path.join(path, filename))
    return jpeg_paths

def provide_data_self(path):
    """Loads every JPEG listed by `getimage(path)` into memory.

    Each image is decoded, converted to float32 with `convert_image_dtype`
    (uint8 pixels become values in [0, 1]) and resized to 128x128 with
    bilinear interpolation. Labels are assigned by position in the file
    list: the first 50 images get label 0, every later image gets label 1,
    and the labels are one-hot encoded with depth 10.

    Args:
      path: Directory to scan for images (passed to `getimage`).

    Returns:
      A `(labels, images)` tuple: `labels` is a `Tensor` of shape
      [num_images, 10]; `images` is a numpy array of shape
      [num_images, 128, 128, channels] (channels as stored in each JPEG).
    """
    paths = getimage(path)
    labels = []
    images = []

    # Build the decode/convert/resize ops once and feed each file's bytes
    # through them; creating fresh ops per image grows the graph per file.
    jpeg_bytes = tf.placeholder(tf.string, shape=[])
    decoded = tf.image.decode_jpeg(jpeg_bytes)
    as_float = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
    # method=0 is bilinear; the target size is given as [height, width].
    resized = tf.image.resize_images(as_float, [128, 128], method=0)

    with tf.Session() as sess:
        for i, img_path in enumerate(paths, start=1):
            print("image path:", img_path)
            # Close the file handle once its bytes have been read.
            with tf.gfile.GFile(str(img_path), 'rb') as f:
                raw = f.read()
            images.append(sess.run(resized, feed_dict={jpeg_bytes: raw}))
            # Position-based labelling: images 1..50 -> 0, the rest -> 1.
            labels.append(1 if i > 50 else 0)

    # Explicit integer dtype keeps tf.one_hot valid even if no image was found
    # (np.asarray([]) would otherwise be float64).
    labels = tf.one_hot(np.asarray(labels, dtype=np.int64), 10)
    images = np.asarray(images)
    print("----------labels:", labels.shape)
    print("----------images:", images.shape)
    return labels, images


def get_batch_data(path, batchSize):
    """Builds queue-based batches from the JPEGs in directory `path`.

    All images are loaded up front via `provide_data_self`, then sliced one
    example at a time (in order, cycling indefinitely because
    `num_epochs=None`) and grouped by a single-threaded `tf.train.batch`.

    Args:
      path: Directory of images, forwarded to `provide_data_self`.
      batchSize: Number of examples per batch; partial batches are dropped.

    Returns:
      `(image_batch, label_batch)` tensors.
    """
    one_hot_labels, image_array = provide_data_self(path)
    producer = tf.train.slice_input_producer(
        [image_array, one_hot_labels], shuffle=False, num_epochs=None)
    batched_images, batched_labels = tf.train.batch(
        producer,
        batch_size=batchSize,
        num_threads=1,
        capacity=64,
        allow_smaller_final_batch=False)
    return batched_images, batched_labels


def float_image_to_uint8(image):
  """Convert a float image in [-1, 1] to [0, 255] uint8.

  Pixels are mapped with `x * 128 + 128` and clipped to [0, 255] before the
  cast. Without the clip, an input of exactly `1` maps to 256, which does not
  fit in uint8, so the cast result cannot be relied on. For inputs in
  [-1, 1) the output is the same as an unclipped cast; `1 - epsilon` still
  maps to 255.

  Args:
    image: A float image tensor. Values are expected in [-1, 1]; anything
      outside that range is saturated to 0 or 255.

  Returns:
    Input image cast to uint8 and with integer values in [0, 255].
  """
  image = (image * 128.0) + 128.0
  # Saturate instead of letting the float -> uint8 cast overflow.
  image = tf.clip_by_value(image, 0.0, 255.0)
  return tf.cast(image, tf.uint8)

网络构建

network.py

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN CIFAR example using TFGAN."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

#from slim.nets import dcgan
from nets import dcgan
tfgan = tf.contrib.gan
layers = tf.contrib.layers

def _last_conv_layer(end_points):
  """"Returns the last convolutional layer from an endpoints dictionary."""
  conv_list = [k if k[:4] == 'conv' else None for k in end_points.keys()]
  conv_list
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值