介绍了TensorFlow 的数据读取的基本原理,再来看如何i卖取CIFAR-10数据。在CIFAR-10 数据集中,文件data batch I .bin 、data batch 2.bin 、data_batch 5 . bin 和test_ batch.bin 中各有10000 个样本。一个样本由3073 个字节组成,第一个字节为标签( label ),剩下3072 个字节为图像数据(官方说明文档)。
样本和样本之间没高多余的字节分割, 因此这几个二进制文件的大小都是30730000 字节。如何用TensorFlow读取CIFAR-10 数据呢?,步骤和上一篇文章(TensorFlow的数据读取机制)一样。
- 第一步,用tf. train .string_ input producer 建立队列。
- 第二步,通过reader.read 读数据。在上一篇文章中,一个文件就是一张图片,因此用的reader 是tf. WholeFileReader() 。CIFAR-10 数据是以固定字节存在文件中的,一个文件中含再多个样本3 因此不能使用tf. WholeFileReader (),而是用tf.FixedLengthRecordReader() 。
- 第三步,调用tf. train . start_ queue_ runners 。
- 最后,通过sess.run()取出图片结果。
遵循上面的步骤,本文会做一个实验:将CIFAR-10 数据集中的图片读取出来,并保存为.jpg 恪式。对应的程序为cifar 10 extract. py 。看步骤中的tf. train.string_ input_produ cer ,tf.FixedLengthRecordReader ()、tf.train.start_queue_ runners 、sess.run ()都在什么地方。按照程序的执行顺序来看:
#coding: utf-8
# 导入当前目录的cifar10_input,这个模块负责读入cifar10数据
import cifar10_input
# 导入TensorFlow和其他一些可能用到的模块。
import tensorflow as tf
import os
import scipy.misc
def inputs_origin(data_dir):
# filenames一共5个,从data_batch_1.bin到data_batch_5.bin
# 读入的都是训练图像
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in xrange(1, 6)]
# 判断文件是否存在
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# 将文件名的list包装成TensorFlow中queue的形式
filename_queue = tf.train.string_input_producer(filenames)
# cifar10_input.read_cifar10是事先写好的从queue中读取文件的函数
# 返回的结果read_input的属性uint8image就是图像的Tensor
read_input = cifar10_input.read_cifar10(filename_queue)
# 将图片转换为实数形式
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
# 返回的reshaped_image是一张图片的tensor
# 我们应当这样理解reshaped_image:每次使用sess.run(reshaped_image),就会取出一张图片
return reshaped_image
if __name__ == '__main__':
# 创建一个会话sess
with tf.Session() as sess:
# 调用inputs_origin。cifar10_data/cifar-10-batches-bin是我们下载的数据的文件夹位置
reshaped_image = inputs_origin('cifar10_data/cifar-10-batches-bin')
# 这一步start_queue_runner很重要。
# 我们之前有filename_queue = tf.train.string_input_producer(filenames)
# 这个queue必须通过start_queue_runners才能启动
# 缺少start_queue_runners程序将不能执行
threads = tf.train.start_queue_runners(sess=sess)
# 变量初始化
sess.run(tf.global_variables_initializer())
# 创建文件夹cifar10_data/raw/
if not os.path.exists('cifar10_data/raw/'):
os.makedirs('cifar10_data/raw/')
# 保存30张图片
for i in range(30):
# 每次sess.run(reshaped_image),都会取出一张图片
image_array = sess.run(reshaped_image)
# 将图片保存
scipy.misc.toimage(image_array).save('cifar10_data/raw/%d.jpg' % i)
inputs_ origin 是一个函数。这个函数中包含了前两个步骤, tf.train.string_inpu t_pr:oducer 和使用reader 。函数的返回值reshaped_image 是一个Tensor,对应一张训练图像。下面要做的并不是直接运行sess.run(reshaped_image),而是使用threads = tf. train. start_ queue_ runners( sess=sess)。只高调用过tf. train.start_ queue_ runners 后,才会让系统中的所高队列真正地“运行”,开始从文件中读数据。如果不调用这条i吾旬,系统将会一直等待。
最后用sess.run(reshaped_image)取出训练图片并保存。此程序一共在文件夹cifar10data/raw/中保存了30 张图片。读者可以打开该文件夹,看到原始的CIFAR-10 训练图片。再回过头来看inputs_ origin 函数:
def inputs_origin(data_dir):
# filenames一共5个,从data_batch_1.bin到data_batch_5.bin
# 读入的都是训练图像
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
for i in xrange(1, 6)]
# 判断文件是否存在
for f in filenames:
if not tf.gfile.Exists(f):
raise ValueError('Failed to find file: ' + f)
# 将文件名的list包装成TensorFlow中queue的形式
filename_queue = tf.train.string_input_producer(filenames)
# cifar10_input.read_cifar10是事先写好的从queue中读取文件的函数
# 返回的结果read_input的属性uint8image就是图像的Tensor
read_input = cifar10_input.read_cifar10(filename_queue)
# 将图片转换为实数形式
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
# 返回的reshaped_image是一张图片的tensor
# 我们应当这样理解reshaped_image:每次使用sess.run(reshaped_image),就会取出一张图片
return reshaped_image
tf.train.string_input_producer(filenames )创建了一个文件名队列,真中filenames 是一个列表,包含从data_batch_1.bin 到data_batch_5.bin 一共5 个文件名。这正好对应了CIFAR-10 的训练集。cifar10_ input.read_ cifar_10(filename_queue)对应“使用reader ”的步骤。为此需要查看cifar10_input.py中的read cifar10函数,其中关键的代码如下。
# Dimensions of the images in the CIFAR-10 dataset.
# See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
# input format.
label_bytes = 1 # 2 for CIFAR-100
result.height = 32
result.width = 32
result.depth = 3
image_bytes = result.height * result.width * result.depth
# Every record consists of a label followed by the image, with a
# fixed number of bytes for each.
record_bytes = label_bytes + image_bytes
# Read a record, getting filenames from the filename_queue. No
# header or footer in the CIFAR-10 format, so we leave header_bytes
# and footer_bytes at their default of 0.
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
result.key, value = reader.read(filename_queue)
i吾句tf.FixedLengthRecordReader(record_ bytes=record _bytes )创建了一个reader , 包每次在文件中读取record_bytes 字节的数据,直到文件结束。结合代码, record_bytes 就等于1+32*32*3,即3073 ,正好对应CIFAR-10中一个样本的字节长度。使用reader.read(filename_queue)后, reader 从之前建立好的文件名队列中渎职数据(以Tensor 的形式)。简单处理结果后由函数返回。至此,读者应当对CIFAR-10 数据的读取流程及TensorFlow 的读取机制相当熟悉了。