原文地址: http://blog.csdn.net/u010911921/article/details/70577697
这段一直在用Tensorflow来做深度学习上的相关工作,然后对Tensorflow读取数据的方式进行实现。特地总结一下。首先是读取二进制图片数据,这里采用的是CIFAR-10的二进制数据
## 1.CIFAR-10数据集 CIFAR-10数据集合是包含60000张`32*32*3`的图片,其中每个类包含6000张图片,总共10类。在这60000张图片中50000张是训练集合,10000张是测试集合。其中二进制的图片保存的格式如下所示:
2.Tensorflow读取数据
从Tensorflow的官网可以看到从文件中读取数据的流程主要是一下步骤:
- The list of filenames
- (Optional) filename shuffling
- (Optional) epoch limit
- Filename queue
- A Reader for the file format
- A decoder for a record read by the reader
- (Optional) preprocessing
- Example queue
按照这样一个流程,首选应该将CIFAR-10的训练集和测试集合,生成文件名列表,然后在讲这个文件名列表传递给tf.train.string_input_producer
函数创建一个用于保存文件名称的FIFO的队列,最后用tensor flow产生的reader
从队列中读取数据。当reader
读到数据就需要用tf.decode_raw
函数对读取到的二进制数进行解码。
结束了上述操作,下面就需要采用另一个queue去batch together examples来为训练和测试提供数据。采用tf.train.shuffle_batch
将上面生成的image
和label
传入函数即可完成。
3.开始训练
当tf.train.shuffle_batch
生成batch以后就开始利用tf.train.start_queue_runners
函数启动队列,然后开始整个计算图,官网给的建议是如下形式:
init_op = tf.global_variables_initializer()
with tf.Session as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess= sess,coord = coord)
try:
while not coord.should_stop():
#run training steps or whatever
sess.run(train_op)
except tf.errors.OutOfRangeError:
print('Done training --epoch limit reached')
finally:
# when done,ask the threads to stop
coord.request_stop()
coord.join(threads)
4.代码实现
在神经网络的训练中由于每训练k步以后就会对网络进行一次测试,所以需要在上述步骤中,增加动态选择文件名称队列这样一个过程,可以由tf.QueueBase.from_list
函数进行实现,然后reader
从返回的文件名称队列中读取数据。
整个过程的实现如下所示:
#!/usr/bin/env python3
# --*-- encoding:utf-8 --*--
import tensorflow as tf
import numpy as np
import os
def read_cifar10(data_dir,is_traing,batch_size,shuffle):
"""
:param data_dir:数据保存路径
:param is_traing:True从训练集获取数据,False从测试集获取数据
:param batch_size: batch_size的大小
:param shuffle: bool,是否进行shuffle操作
:return:
"""
img_width = 32
img_height = 32
img_depth = 3
label_bytes = 1
img_bytes = img_height * img_width *img_depth
with tf.name_scope("input") as scope:
#训练集合的文件列表
train_filenames = [os.path.join(data_dir,
'data_batch_%d.bin'%ii) for ii in np.arange(1,6)]
#测试集合的文件列表
val_filenames = [os.path.join(data_dir,'test_batch.bin')]
#训练集和测试集合的文件名称队列
train_queue = tf.train.string_input_producer(train_filenames)
val_queue = tf.train.string_input_producer(val_filenames)
#挑选文件队列,实现training的过程中测试
queue_select = tf.cond(is_traing,
lambda :tf.constant(0),
lambda :tf.constant(1) )
queue = tf.QueueBase.from_list(queue_select,[train_queue,val_queue])
#从队列中读取固定长度的数据
reader = tf.FixedLengthRecordReader(label_bytes+img_bytes)
key,value = reader.read(queue)
recode_bytes = tf.decode_raw(value,tf.uint8)
#获取label
label = tf.slice(recode_bytes,[0],[label_bytes])
label = tf.cast(label,tf.int32)
#获取image
image_raw = tf.slice(recode_bytes,[label_bytes],[img_bytes])
image_raw = tf.reshape(image_raw,[img_depth, img_height, img_width])
image = tf.transpose(image_raw,[1,2,0])
image = tf.cast(image,tf.float32)
#对每一张图片进行标准化操作,可选操作此处可以进行对图片的各种操作
image = tf.image.per_image_standardization(image)
if shuffle:
images, label_batch= tf.train.shuffle_batch([image,label],
batch_size=batch_size,
num_threads=16,
capacity=512+3*batch_size,
min_after_dequeue=512,
allow_smaller_final_batch=True)
else:
images, label_batch = tf.train.batch([image, label],
batch_size=batch_size,
num_threads=16,
capacity=512 + 3*batch_size,
allow_smaller_final_batch=True)
label_batch = tf.cast(label_batch,tf.int32)
return images,label_batch
整个过程是采用VGG-16的网络模型进行训练的,在迭代16000次,tensorboard展示的结果如图所示:
code下载地址https://github.com/ZhichengHuang/LearnTensorflowCode