数据读取部分代码
data_provider.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing the CIFAR data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import os
import numpy as np
#from slim.datasets import dataset_factory as datasets
from datasets import dataset_factory as datasets
slim = tf.contrib.slim
def provide_data(batch_size, dataset_dir, dataset_name='cifar10',
                 split_name='train', one_hot=True):
    """Build a queue-based input pipeline over a slim dataset.

    The dataset is looked up through `dataset_factory.get_dataset`, read with a
    slim `DatasetDataProvider` (shuffled only for the 'train' split), rescaled
    with `(pixel - 128) / 128`, and batched with `tf.train.batch`.

    Args:
      batch_size: Number of examples per batch.
      dataset_dir: Directory holding the dataset files; `None` lets the
        dataset factory use its default.
      dataset_name: Name passed to `dataset_factory.get_dataset`
        (defaults to 'cifar10').
      split_name: Dataset split to read, e.g. 'train' or 'test'. Only 'train'
        enables shuffling.
      one_hot: If True, labels are returned one-hot encoded over
        `dataset.num_classes`; otherwise as a flat vector of integer labels.

    Returns:
      A tuple `(images, labels, num_samples, num_classes)`:
        images: Batched image tensor with values in [-1, 1).
        labels: `[batch_size, num_classes]` one-hot tensor if `one_hot`,
          else a `[batch_size]` tensor.
        num_samples: `dataset.num_samples`.
        num_classes: `dataset.num_classes`.
    """
    print("provide_data.............")
    is_training = split_name == 'train'
    dataset = datasets.get_dataset(dataset_name, split_name,
                                   dataset_dir=dataset_dir)
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset,
        common_queue_capacity=5 * batch_size,
        common_queue_min=batch_size,
        shuffle=is_training)
    raw_image, raw_label = provider.get(['image', 'label'])

    # Center pixels on 0 and scale so 8-bit values land in [-1, 1).
    scaled_image = (tf.to_float(raw_image) - 128.0) / 128.0
    print("---------------image:", scaled_image)

    # tf.train.batch registers a QueueRunner that prefetches examples.
    image_batch, label_batch = tf.train.batch(
        [scaled_image, raw_label],
        batch_size=batch_size,
        num_threads=32,
        capacity=5 * batch_size)
    label_batch = tf.reshape(label_batch, [-1])

    num_classes = dataset.num_classes
    if one_hot:
        label_batch = tf.one_hot(label_batch, num_classes)

    num_samples = dataset.num_samples
    print("num_samples:", num_samples)
    print("dataset.num_classes", num_classes)
    return image_batch, label_batch, num_samples, num_classes
def getimage(path):
    """Return the paths of the JPEG files directly inside directory `path`.

    Only file names ending in ".jpg" or ".jpeg" (case-sensitive) are kept, and
    each matching name is printed as it is found.

    Args:
      path: Directory to scan. A trailing path separator is optional.

    Returns:
      A list of file paths, sorted by file name.
    """
    image_paths = []
    # os.listdir returns entries in arbitrary order. Sorting keeps the result
    # deterministic, which matters because callers assign labels by position.
    for filename in sorted(os.listdir(path)):
        if filename.endswith((".jpg", ".jpeg")):
            print(filename)
            # os.path.join inserts a separator when `path` lacks one; plain
            # string concatenation would produce "dirimg.jpg".
            image_paths.append(os.path.join(path, filename))
    return image_paths
def provide_data_self(path):
    """Load every JPEG found by `getimage(path)` into memory.

    Each image is decoded, converted to float32 with
    `tf.image.convert_image_dtype` (uint8 pixels are rescaled to [0, 1]), and
    resized to 128x128 with `method=0`. Labels are assigned by position in the
    list returned by `getimage`: the first 50 images get class 0 and every
    later image gets class 1.

    Args:
      path: Directory containing the images; passed to `getimage`.

    Returns:
      A `(labels, images)` tuple (note the order):
        labels: One-hot `Tensor` of shape [num_images, 10].
        images: NumPy array stacking the decoded, resized images.

    Raises:
      ValueError: if no .jpg/.jpeg files are found in `path`.
    """
    paths = getimage(path)
    if not paths:
        # Without this check an empty directory fails later with an obscure
        # dtype error from tf.one_hot (np.asarray([]) is float64).
        raise ValueError("No .jpg/.jpeg images found in %r" % (path,))

    # Build the decode -> float -> resize pipeline ONCE and feed each file's
    # bytes through a placeholder. Creating these ops inside the loop added
    # new nodes to the default graph for every image.
    jpeg_data = tf.placeholder(tf.string, shape=[])
    img_op = tf.image.decode_jpeg(jpeg_data)
    img_op = tf.image.convert_image_dtype(img_op, dtype=tf.float32)
    # The size argument of resize_images is [height, width].
    img_op = tf.image.resize_images(img_op, [128, 128], method=0)

    images = []
    labels = []
    with tf.Session() as sess:
        for i, img_path in enumerate(paths, start=1):
            print("image path:", img_path)
            # Close each file handle deterministically instead of leaking it.
            with tf.gfile.FastGFile(str(img_path), 'rb') as f:
                raw = f.read()
            images.append(sess.run(img_op, feed_dict={jpeg_data: raw}))
            # Position-based labelling: files 1..50 are class 0, the rest
            # are class 1.
            labels.append(1 if i > 50 else 0)

    labels = np.asarray(labels, dtype=np.int64)
    labels = tf.one_hot(labels, 10)
    images = np.asarray(images)
    print("----------labels:", labels.shape)
    print("----------images:", images.shape)
    return labels, images
def get_batch_data(path, batchSize):
    """Batch the in-memory images and labels loaded from `path`.

    Data comes from `provide_data_self(path)`. Examples are sliced one at a
    time in file order (`shuffle=False`), with no epoch limit
    (`num_epochs=None`), and then grouped by `tf.train.batch`. As with any
    `tf.train.batch` pipeline, queue runners must be started before the
    returned tensors can be evaluated.

    Args:
      path: Directory of JPEG images, forwarded to `provide_data_self`.
      batchSize: Number of examples per batch.

    Returns:
      `(image_batch, label_batch)` tensors. Only full batches are produced
      (`allow_smaller_final_batch=False`).
    """
    one_hot_labels, all_images = provide_data_self(path)
    producer = tf.train.slice_input_producer(
        [all_images, one_hot_labels],
        shuffle=False,
        num_epochs=None)
    batched = tf.train.batch(
        producer,
        batch_size=batchSize,
        num_threads=1,
        capacity=64,
        allow_smaller_final_batch=False)
    image_batch, label_batch = batched
    return image_batch, label_batch
def float_image_to_uint8(image):
    """Convert a float image in [-1, 1] to a uint8 image in [0, 255].

    Pixels are mapped with `x * 128 + 128` and clipped to [0, 255] before the
    cast, so an input of exactly 1.0 becomes 255.

    Args:
      image: A float image tensor. Values are expected in [-1, 1]; values
        outside that range are clamped.

    Returns:
      `image` rescaled, clipped, and cast to `tf.uint8`.
    """
    image = (image * 128.0) + 128.0
    # Without clipping, 1.0 maps to 256.0, which the uint8 cast wrapped to 0
    # (a white pixel rendered black).
    image = tf.clip_by_value(image, 0.0, 255.0)
    return tf.cast(image, tf.uint8)
网络构建
network.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Networks for GAN CIFAR example using TFGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
#from slim.nets import dcgan
from nets import dcgan
tfgan = tf.contrib.gan
layers = tf.contrib.layers
def _last_conv_layer(end_points):
""""Returns the last convolutional layer from an endpoints dictionary."""
conv_list = [k if k[:4] == 'conv' else None for k in end_points.keys()]
conv_list