tensorflow如何队列式同步批量读取照片(1)

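下载即用。注意:运行前必须先下载 CIFAR-10 数据集的二进制(binary)版本,并解压到 datasets 文件夹下,使 data_batch_1.bin ~ data_batch_5.bin 位于 datasets/cifar-10-batches-bin/ 目录中。
以下代码针对二进制(.bin)文件的读取。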

import tensorflow as tf
from tensorflow import flags
import os
from scipy import misc
# Command-line flags: location of the CIFAR-10 data and whether the session
# should log the device each op is placed on.
flags.DEFINE_string(
    'data_dir', 'datasets/', 'Path to the CIFAR-10 data directory.')
flags.DEFINE_boolean(
    'log_device_placement', False, 'Whether to log device placement.')
FLAGS = flags.FLAGS

# In this file this is only read by distorted_inputs, to size the shuffle
# queue. NOTE(review): presumably kept small so the queue fills quickly;
# confirm it is not meant to be the real number of training records.
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 500
def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, num_preprocess_threads=16):
  """Build a shuffled batch of images and labels from single examples.

  Args:
    image: one image tensor; it is batched along a new leading dimension.
    label: the matching label tensor. After batching it is reshaped to
      [batch_size], so it is assumed to hold exactly one element per example.
    min_queue_examples: minimum number of examples left in the shuffle queue
      after each dequeue; it also forms the base of the queue capacity.
    batch_size: number of examples per batch.
    num_preprocess_threads: number of threads enqueuing examples. Defaults to
      16, the value that was previously hard-coded.

  Returns:
    A tuple (images, labels): the batched images and a 1-D label tensor of
    shape [batch_size].
  """
  images, label_batch = tf.train.shuffle_batch(
      [image, label],
      batch_size=batch_size,
      num_threads=num_preprocess_threads,
      # Leave room for 3 batches above the min_after_dequeue floor so the
      # enqueue threads are not constantly blocked.
      capacity=min_queue_examples + 3 * batch_size,
      min_after_dequeue=min_queue_examples)

  # Attach an image summary of the batch.
  tf.summary.image('images', images)

  return images, tf.reshape(label_batch, [batch_size])
def read_cifar10(filename_queue):
  """Read and parse one CIFAR-10 record from the files in `filename_queue`.

  Each fixed-length record is one label byte followed by the image bytes,
  stored channel-major as (depth, height, width).

  Args:
    filename_queue: queue of file names for the reader to consume, e.g. the
      output of tf.train.string_input_producer.

  Returns:
    An object with these attributes:
      height, width, depth: Python ints 32, 32 and 3.
      key: the record key returned by the reader.
      label: int32 tensor of shape [1] holding the label byte.
      uint8image: uint8 tensor of shape [height, width, depth].
  """
  class CIFAR10Record(object):
    pass

  record = CIFAR10Record()
  record.height, record.width, record.depth = 32, 32, 3

  num_label_bytes = 1
  num_image_bytes = record.depth * record.height * record.width

  file_reader = tf.FixedLengthRecordReader(
      record_bytes=num_label_bytes + num_image_bytes)
  record.key, raw_value = file_reader.read(filename_queue)

  # Reinterpret the raw record string as a flat vector of bytes.
  byte_vector = tf.decode_raw(raw_value, tf.uint8)

  # Leading byte(s): the class label.
  label_slice = tf.slice(byte_vector, [0], [num_label_bytes])
  record.label = tf.cast(label_slice, tf.int32)

  # Remaining bytes: the image in [depth, height, width] order. Transpose so
  # channels come last.
  image_slice = tf.slice(byte_vector, [num_label_bytes], [num_image_bytes])
  chw_image = tf.reshape(image_slice,
                         [record.depth, record.height, record.width])
  record.uint8image = tf.transpose(chw_image, [1, 2, 0])

  return record
def distorted_inputs(data_dir, batch_size):
  """Build a randomly distorted, shuffled batch of CIFAR-10 training data.

  Args:
    data_dir: directory containing data_batch_1.bin ... data_batch_5.bin.
    batch_size: number of images per batch.

  Returns:
    A tuple (images, labels). images is a float32 batch of 24x24 randomly
    cropped, flipped, brightness/contrast-jittered and per-image standardized
    images. labels is a 1-D tensor of shape [batch_size].

  Raises:
    ValueError: if any of the expected data files is missing.
  """
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in range(1, 6)]
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  filename_queue = tf.train.string_input_producer(filenames)

  # read_cifar10 returns a record object (an instance, not a class) whose
  # attributes hold the parsed tensors and the image dimensions.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = 24
  width = 24

  # Random spatial crop that keeps every channel. Use the depth reported by
  # read_cifar10 rather than a hard-coded 3 so the two stay in sync.
  distorted_image = tf.random_crop(reshaped_image,
                                   [height, width, read_input.depth])

  distorted_image = tf.image.random_flip_left_right(distorted_image)

  distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
  distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)

  # Scale each image to zero mean and unit variance.
  float_image = tf.image.per_image_standardization(distorted_image)

  # Keep at least this many examples in the shuffle queue to get good mixing.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                           min_fraction_of_examples_in_queue)
  print ('Filling queue with %d CIFAR images before starting to train. '
         'This will take a few minutes.' % min_queue_examples)

  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size)


data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
batch_size = 128
images, labels = distorted_inputs(data_dir=data_dir, batch_size=batch_size)

# Directory the sampled images are written to. imsave does not create missing
# directories, so make sure it exists first.
output_dir = 'photo'
if not os.path.isdir(output_dir):
  os.makedirs(output_dir)

sess = tf.Session(config=tf.ConfigProto(
    log_device_placement=FLAGS.log_device_placement))
sess.run(tf.global_variables_initializer())

# The coordinator lets us stop and join the queue-runner threads cleanly
# instead of leaving them running when the loop ends or raises.
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
  for i in range(100):
    images_value, labels_value = sess.run([images, labels])
    # Iterate the fetched batch directly so the loop cannot drift from
    # batch_size.
    for j, (image_value, label_value) in enumerate(
        zip(images_value, labels_value)):
      filename = '%d_%d_%d.png' % (i, j, label_value)
      # NOTE(review): scipy.misc.imsave was deprecated in SciPy 1.0 and
      # removed in 1.2; this needs an older SciPy (or a switch to another
      # image writer). The images here are per-image standardized floats,
      # so the saved PNGs presumably show rescaled values rather than the
      # original pixels; confirm that is intended.
      misc.imsave(os.path.join(output_dir, filename), image_value)
finally:
  coord.request_stop()
  coord.join(threads)
  sess.close()
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值