有了tfrecord数据集,自然就要从数据集中把原始训练数据解出来进行训练,tensorflow提供了一整套方法来处理tfrecord数据集的读取,包括读取函数和多线程处理数据的方法。
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def read_and_decode(filename):
#"C:\\Python34\\tensorflow\\tfrecord_train_plane_1.tfrecords"
filename_queue = tf.train.string_input_producer([filename], shuffle = False) #使用初始化时提供的文件列表创建一个输入队列,输入队列中原始的元素为文件列表中的所有文件
reader = tf.TFRecordReader()#创建一个reader来读取TFRecord文件中的样例
_, serialized_example = reader.read(filename_queue) #从文件中读出一个样例,返回文件名和文件
#batch = tf.train.batch(tensors=[serialized_example],batch_size=3)
#解析读入的一个样例
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([3], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
}) #取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.uint8)#采用decode_raw将字符串解析成图像对应的像素数组
image = tf.reshape(image, [128, 128, 1])#将读入的数据重新整理为128*128的图像,1为图像的通道数
image = (image-tf.reduce_min(image))/(tf.reduce_max(image)-tf.reduce_min(image))#将图像数据归一化
label = tf.cast(features['label'], tf.float32)#将标签数据改为实数型
return image, label
if __name__=="__main__":
img, label = read_and_decode("C:\\tensorflowprogram\\tensorflow\\tfrecord_train_plane\\tfrecord_train_plane_128_330.tfrecords")
count = 0
#采用tf.train.shuffle_batch函数来将单个的样例组织成batch的形式输出,[img,label]给出了需要组合的元素,batch_size为每次出队得到的样例数量,capacity给出了队列的最大容量,min_after_dequeue限制了出队时最少元素的个数
img_batch, label_batch = tf.train.shuffle_batch([img, label],batch_size=20, capacity=1060,min_after_dequeue = 30)
with tf.Session() as sess: #开始一个会话
init_op = tf.global_variables_initializer()
sess.run(init_op)
coord=tf.train.Coordinator()#声明tf.train.Coordinator来协同不同的进程,并启动线程
threads= tf.train.start_queue_runners(coord=coord)
#测试程序,测试程序是否能批量读入数据
for i in range (15):
k,l=sess.run([img_batch,label_batch])
print(type(k))
print(k)
print(l)
#print(type(l))
print(i)
coord.request_stop()
coord.join(threads)
经过测试,该程序合格,能够不间断地读取数据,用来输入神经网络模型进行训练。