利用Tensorflow的队列多线程读取数据
在tensorflow中,有三种方式输入数据
1利用feed_dict送入numpy数组
2利用队列从文件中直接读取数据
3预加载数据
其中第一种方式很常用,在tensorflow的MNIST训练源码中可以看到,通过feed_dict={},可以将任意数据送入tensor中。
第二种方式相比于第一种,速度更快,可以利用多线程的优势把数据送入队列,再以batch的方式出队,并且在这个过程中可以很方便地对图像进行随机裁剪、翻转、改变对比度等预处理,同时可以选择是否对数据随机打乱,可以说是非常方便。该部分的源码在tensorflow官方的CIFAR-10训练源码中可以看到,但是对于刚学习tensorflow的人来说,比较难以理解,本篇博客就当成我调试完成后写的一篇总结,以防自己再忘记具体细节。
读取CIFAR-10数据集
按照第一种方式的话,CIFAR-10的读取只需要写一段非常简单的代码即可将测试集与训练集中的图像分别读取:
path = 'E:\Dataset\cifar-10\cifar-10-batches-py'
# extract train examples
num_train_examples = 50000
x_train = np.empty((num_train_examples, 32, 32, 3), dtype='uint8')
y_train = np.empty((num_train_examples), dtype='uint8')
for i in range(1, 6):
fpath = os.path.join(path, 'data_batch_' + str(i))
(x_train[(i - 1) * 10000: i * 10000, :, :, :], y_train[(i - 1) * 10000: i * 10000]) = load_and_decode(fpath)
# extract test examples
fpath = os.path.join(path, 'test_batch')
x_test, y_test = load_and_decode(fpath)
return x_train, y_train, x_test, np.array(y_test)
其中load_and_decode函数只需要按照CIFAR-10官网给出的方式decode就行,最终返回的x_train是一个[50000, 32, 32, 3]的ndarray,但对于ndarray来说,进行预处理就要麻烦很多,为了取mini-SGD的batch,还自己写了一个类,通过调用train_set.next_batch()函数来取,总而言之就是什么都要自己动手,效率确实不高
但对于第二种方式,读取起来就要麻烦很多,但使用起来,又快又方便
以读取cifar10为例。
**#1、读取文件,生成文件名列表**
path = 'E:\Dataset\cifar-10\cifar-10-batches-py'
filenames = [os.path.join(path, 'data_batch_%d' % i) for i in range(1, 6)]
**#2、利用tf.train.string_input_producer函数生成一个读取队列**
filename_queue = tf.train.string_input_producer(filenames)
def read_cifar10(filename_queue):
label_bytes = 1
IMAGE_SIZE = 32
CHANNELS = 3
image_bytes = IMAGE_SIZE*IMAGE_SIZE*3
record_bytes = label_bytes+image_bytes
# **3、定义一个 reader。**
#若读取列表中为单独文件则用tf.WholeFileReader()
reader = tf.FixedLengthRecordReader(record_bytes)
key, value = reader.read(filename_queue)
record_bytes = tf.decode_raw(value, tf.uint8)
label = tf.strided_slice(record_bytes, [0], [label_bytes])
depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],
[label_bytes + image_bytes]),
[CHANNELS, IMAGE_SIZE, IMAGE_SIZE])
image = tf.transpose(depth_major, [1, 2, 0])
return image, label #tensor格式
定义一个reader,来读取固定长度的数据,这个固定长度是由CIFAR-10数据集图片的存储格式决定的,1byte的标签加上32 *32 *3长度的图像,3代表RGB三通道,由于图片的是按[channel, height, width]的格式存储的,为了变为常用的[height, width, channel]维度,需要在17行reshape一次图像,最终我们提取出了一副完整的图像与对应的标签。