本文简单的介绍了使用tensorflow处理TFRecords文件。
文章目录
TFRecords文件简介
TFRecords是文件存储的有结构的序列化字符块,它是tensorflow推荐的标准文件格式。
创建TFRecords文件
整体流程
- 创建一个TFRecordWriter
- 通过tf.train.Example创建一个example(内部需要定义features方法)
- 将example序列化写入文件
- 关闭文件流
具体代码
# -*- coding: UTF-8 -*-
import tensorflow as tf
'''write.py'''
writer = tf.python_io.TFRecordWriter('data/stat.tfrecord')
for i in range(1, 3):
example = tf.train.Example(
features = tf.train.Features(
feature = {
'id': tf.train.Feature(int64_list=tf.train.Int64List(value=[i])),
'age': tf.train.Feature(int64_list=tf.train.Int64List(value=[i*24])),
'income': tf.train.Feature(float_list=tf.train.FloatList(value=[i*2048.0])),
'outgo': tf.train.Feature(float_list=tf.train.FloatList(value=[i*1024.0]))
}
)
)
writer.write(example.SerializeToString())
# 关闭输出流
writer.close()
读取TFRecords文件
整体流程
- 创建文件名队列
- 创建TFRecordReader,并调用read方法
- Decoder
- 启动进程
具体代码
# -*- coding: UTF-8 -*-
import tensorflow as tf
'''reader.py'''
# 创建文件队列
filename_queue = tf.train.string_input_producer(['data/stat.tfrecord'])
# 创建读取TFRecords文件的reader
reader = tf.TFRecordReader()
# 读,去除stat.tfrecords文件中的一条序列化的样例serialized_example
_, serialized_example = reader.read(filename_queue)
# Decoder 将一条序列化的的样例转换为包含所有特征的张量
features = tf.parse_single_example(
serialized_example,
features={
'id': tf.FixedLenFeature([], tf.int64),
'age': tf.FixedLenFeature([], tf.int64),
'income': tf.FixedLenFeature([], tf.float32),
'outgo': tf.FixedLenFeature([], tf.float32)
}
)
# print(features)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess=sess) # 启动执行入队操作的后台进程
for i in range(2):
example = sess.run(features)
print(example)
工程应用中使用协调器
队列越界异常的产生
队列越界异常即是读取数据次数超过了队列的长度,错误代码如下,报错ERROR:tensorflow:Exception in QueueRunner
import tensorflow as tf
# 输出的队列异常
# 创建文件名队列filename_queue,并制定遍历两次数据集
filename_queue = tf.train.string_input_producer(['data/stat.tfrecord'], num_epochs=2)
# 创建一个reader
reader = tf.TFRecordReader()
# 通过reader去读数据
_, serialized_example = reader.read(filename_queue)
# 进行Decoder,将序列化的样例转换为其包含的所有张量
features = tf.parse_single_example(
serialized_example,
features={
'id': tf.FixedLenFeature([], tf.int64),
'age': tf.FixedLenFeature([], tf.int64),
'income': tf.FixedLenFeature([], tf.float32),
'outgo': tf.FixedLenFeature([], tf.float32)
}
)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
tf.train.start_queue_runners(sess=sess)
for i in range(2):
example = sess.run(features)
print(example)
运行结果
引入多线程生命周期协调器
通过协调器来监控后台线程。
import tensorflow as tf
# 创建文件队列filename_queue,遍历两次数据集
filename_queue = tf.train.string_input_producer(['data/stat.tfrecord'], num_epochs=2)
# 创建reader
reader = tf.TFRecordReader()
# 读文件队列
_, serialized_example = reader.read(filename_queue)
# Decoder
features = tf.parse_single_example(
serialized_example,
features={
'id': tf.FixedLenFeature([], tf.int64),
'age': tf.FixedLenFeature([], tf.int64),
'income': tf.FixedLenFeature([], tf.float32),
'outgo': tf.FixedLenFeature([], tf.float32)
}
)
with tf.Session() as sess:
# 聚合两种初始化方法
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
sess.run(init_op)
# 创建协调器
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
try:
for i in range(10):
if not coord.should_stop():
example = sess.run(features)
print(example)
except tf.errors.OutOfRangeError:
print('Catch OutOfRangeError')
finally:
# 请求停止
coord.request_stop()
print('Finish reading')
# 等待所有后台进程安全退出
coord.join(threads)
sess.close()
运行结果
参考资料
《深入理解Tensorflow架构设计与实现原理》