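# Ready to use as-is. Important: first download the CIFAR-10 dataset (binary
# version) and extract it into the datasets/ folder, so that the batch files
# end up in datasets/cifar-10-batches-bin/.
# This script reads the CIFAR-10 binary files.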
import tensorflow as tf
from tensorflow import flags
import os
from scipy import misc
# Command-line flags: location of the dataset and whether TensorFlow should
# log which device each op is placed on.
flags.DEFINE_string('data_dir', 'datasets/',
                    """Path to the CIFAR-10 data directory.""")
flags.DEFINE_boolean('log_device_placement', False,
                     """Whether to log device placement.""")
FLAGS = flags.FLAGS

# NOTE(review): in this file the value is only used to size the shuffle queue
# in distorted_inputs (40% of it is kept buffered). The standard CIFAR-10
# training split is much larger (50000), so 500 was presumably chosen to make
# the queue fill quickly. Confirm this is intentional.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 500
def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, num_preprocess_threads=16):
    """Construct a shuffled batch of images and labels from single examples.

    Args:
        image: Tensor for one example image (in this file, a float32
            [height, width, depth] tensor produced by distorted_inputs).
        label: Label tensor for ``image``. It is reshaped below so that the
            batched labels have shape [batch_size], so it must hold exactly
            one element.
        min_queue_examples: Minimum number of examples left in the shuffle
            queue after a dequeue. Larger values give better mixing at the
            cost of memory and start-up time.
        batch_size: Number of examples per batch.
        num_preprocess_threads: Number of threads enqueuing examples.
            Previously hard-coded to 16, which remains the default.

    Returns:
        A tuple ``(images, labels)``. ``images`` is the batched image tensor
        with a leading batch dimension. ``labels`` is a tensor of shape
        [batch_size].
    """
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        # Room for the minimum buffered examples plus a few batches in flight.
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)

    # Emit an image summary of the batch (viewable in TensorBoard).
    tf.summary.image('images', images)

    # Collapse the per-example [1]-shaped labels into a flat [batch_size] vector.
    return images, tf.reshape(label_batch, [batch_size])
def read_cifar10(filename_queue):
    """Read and parse a single example from the CIFAR-10 binary files.

    Each fixed-length record is one label byte followed by the image bytes.
    The image bytes are stored channel-major, i.e. as [depth, height, width].

    Args:
        filename_queue: Queue of filenames for the record reader to consume.

    Returns:
        A ``CIFAR10Record`` instance with these attributes:
          height, width, depth: Python ints 32, 32, 3.
          key: the key tensor returned by the reader for this record.
          label: int32 tensor of shape [1] holding the class label byte.
          uint8image: uint8 tensor of shape [height, width, depth].
    """
    class CIFAR10Record(object):
        pass

    result = CIFAR10Record()
    result.height, result.width, result.depth = 32, 32, 3

    label_bytes = 1
    image_bytes = result.height * result.width * result.depth

    # Records have no header/footer: each is exactly label + image bytes.
    reader = tf.FixedLengthRecordReader(record_bytes=label_bytes + image_bytes)
    result.key, value = reader.read(filename_queue)

    # Reinterpret the whole record as a flat vector of bytes.
    raw = tf.decode_raw(value, tf.uint8)

    # Leading byte: the label, widened to int32.
    result.label = tf.cast(tf.slice(raw, [0], [label_bytes]), tf.int32)

    # Remaining bytes: the image in [depth, height, width] order...
    image_chw = tf.reshape(tf.slice(raw, [label_bytes], [image_bytes]),
                           [result.depth, result.height, result.width])
    # ...transposed into [height, width, depth] order.
    result.uint8image = tf.transpose(image_chw, [1, 2, 0])
    return result
def distorted_inputs(data_dir, batch_size, height=24, width=24):
    """Build shuffled, randomly augmented CIFAR-10 training batches.

    Reads ``data_batch_1.bin`` ... ``data_batch_5.bin`` from ``data_dir``.
    Each image gets a random crop, a random horizontal flip, random brightness
    and contrast changes, and per-image standardization. The results are then
    batched through a shuffle queue.

    Args:
        data_dir: Directory containing the CIFAR-10 binary batch files.
        batch_size: Number of examples per batch.
        height: Height of the random crop. Defaults to 24, the previously
            hard-coded value.
        width: Width of the random crop. Defaults to 24, the previously
            hard-coded value.

    Returns:
        A tuple ``(images, labels)`` as produced by
        ``_generate_image_and_label_batch``. The images are float32 crops of
        size [height, width, depth]; the labels have shape [batch_size].

    Raises:
        ValueError: If a batch file is missing, or if the requested crop does
            not fit inside the source image.
    """
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in range(1, 6)]
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    filename_queue = tf.train.string_input_producer(filenames)
    # read_cifar10 returns a CIFAR10Record instance (not a class) holding the
    # parsed tensors plus the image dimensions.
    read_input = read_cifar10(filename_queue)

    # Fail early with a clear message instead of a runtime op error from
    # random_crop.
    if not (0 < height <= read_input.height and 0 < width <= read_input.width):
        raise ValueError(
            'Crop size %dx%d must fit inside the %dx%d source image.'
            % (height, width, read_input.height, read_input.width))

    # Cast to float32 without rescaling, so pixel values stay on the 0..255 scale.
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    # Random spatial crop that keeps every channel (depth comes from the
    # record instead of a hard-coded 3).
    distorted_image = tf.random_crop(reshaped_image,
                                     [height, width, read_input.depth])
    distorted_image = tf.image.random_flip_left_right(distorted_image)
    # max_delta is in raw pixel units because the image has not been rescaled.
    distorted_image = tf.image.random_brightness(distorted_image,
                                                 max_delta=63)
    distorted_image = tf.image.random_contrast(distorted_image,
                                               lower=0.2, upper=1.8)
    # Normalize each image independently (zero mean, unit variance).
    float_image = tf.image.per_image_standardization(distorted_image)

    # Buffer 40% of NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN in the shuffle queue for
    # good mixing.
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                             min_fraction_of_examples_in_queue)
    print('Filling queue with %d CIFAR images before starting to train. '
          'This will take a few minutes.' % min_queue_examples)
    return _generate_image_and_label_batch(float_image, read_input.label,
                                           min_queue_examples, batch_size)
# Script entry: build the augmented input pipeline and dump NUM_BATCHES batches
# of images to OUTPUT_DIR as PNG files named <batch>_<index>_<label>.png.
BATCH_SIZE = 128
NUM_BATCHES = 100
OUTPUT_DIR = 'photo'

data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
images, labels = distorted_inputs(data_dir=data_dir, batch_size=BATCH_SIZE)

# imsave does not create missing directories, so create the output folder up
# front. MakeDirs also succeeds if the folder already exists.
tf.gfile.MakeDirs(OUTPUT_DIR)

# The context manager guarantees the session is closed when we are done.
with tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as sess:
    sess.run(tf.global_variables_initializer())

    # A Coordinator lets us stop and join the queue-runner threads cleanly.
    # Without it they are left running after the loop finishes.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for i in range(NUM_BATCHES):
            images_value, labels_value = sess.run([images, labels])
            for j, (image, label) in enumerate(zip(images_value, labels_value)):
                # NOTE(review): scipy.misc.imsave is deprecated and was removed
                # in SciPy 1.2. This line needs an older SciPy or a replacement
                # image writer.
                misc.imsave(os.path.join(OUTPUT_DIR,
                                         '%d_%d_%d.png' % (i, j, label)),
                            image)
    finally:
        # Always shut the input threads down, even if a batch or write fails.
        coord.request_stop()
        coord.join(threads)