TFRecord格式
TFRecord格式可以统一不同的原始数据格式,并更加有效地管理不同的属性
格式介绍
其数据通过tf.train.Example Protocol Buffer格式存储:
message Example{
Feature features = 1;
};
message Features{
map<string, Feature> feature = 1;
};
message Feature{
oneof kind{
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
;}
字典:属性名称string——>取值Feature
属性名称:字符串
属性取值:字符串、实数列表、整数列表
样例程序
- TFRecord写操作:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
mnist = input_data.read_data_sets("", dtype=tf.uint8, one_hot=True)
images = mnist.train.images
labels = mnist.train.labels
pixels = images.shape[1]
num_examples = mnist.train.num_examples
filename = ""
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):
# 图像矩阵转化为一个字符串
image_raw = images[index].toString()
example = tf.train.Example(features=tf.train.Feature(feature={
'pixels': _int64_feature(pixels),
'labels': _int64_feature(np.argmax(labels[index])),
'image_raw': _bytes_feature(image_raw)
}))
writer.write(example.SerializeToString())
writer.close()
- TFRecord读操作:
import tensorflow as tf
reader = tf.TFRecordReader()
# 队列维护输入文件列表
filename_queue = tf.train.string_input_producer([""])
# 从文件中读取一个样例
_, serialized_example = reader.read(filename_queue)
# 解析读入的一个样例
features = tf.parse_single_example(
serialized_example,
features={
'image_raw': tf.FixedLenFeature([], tf.string),
'pixels': tf.FixedLenFeature([], tf.int64),
'labels': tf.FixedLenFeature([], tf.int64),
}
)
# 将字符串解析为对应的像素数组
image = tf.decode_raw(features['image_raw'], tf.uint8)
lable = tf.cast(features['labels'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)
sess = tf.Session()
# 启动多线程处理数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(10):
print(sess.run([image, lable, pixels]))
多线程输入数据处理框架
经典的输入数据处理流程:
其中 队列 是Tensorflow多线程输入数据处理框架的基础。
队列
队列与变量类似,修改队列状态的操作主要有Enqueue、EnqueueMany、Dequeue。
TensorFlow中提供两种队列:FIFOQueue(先进先出队列)和RandomShuffleQueue(元素打乱队列,每次从队列中随机选择一个出队,在神经网络中希望每次使用的训练数据尽量随机,所以常用此队列)
操作队列例子:
import tensorflow as tf
# 先进先出队列创建(队列中最多可保存两个元素)
q = tf.FIFOQueue(2, "int32")
# 队列初始化(注意逗号不要忘记,这样既是列表,又满足tf接收的是tensor张量的需求)
init = q.enqueue_many(([0, 10],))
# 出队
x = q.dequeue()
y = x + 1
# 入队
q_inc = q.enqueue([y])
with tf.Session() as sess:
init.run()
for _ in range(5):
v, _ = sess.run([x, q_inc])
print(v)
队列不仅是一种数据结构,还是异步计算张量的一个重要机制。多线程可同时向一个队列中写元素或读元素。
多线程
TensorFlow提供两个类完成多线程协同功能:tf.Coordinator(协同多个线程一起停止)和tf.QueueRunner(启动多个线程来操作同一队列)
import tensorflow as tf
import numpy as np
import threading
import time
def MyLoop(coord, work_id):
while not coord.should_stop():
if np.random.rand() < 0.1:
print("stop from id: %d" % work_id)
coord.request_stop()
else:
print("working on id: %d" % work_id)
time.sleep(1)
# 声明类
coord = tf.train.Coordinator()
# 创建5个线程(将coord类传入每个线程中)
threads = [
threading.Thread(target=MyLoop, args=(coord, i, )) for i in range(5)
]
# 启动所有线程
for t in threads:
t.start()
# 等待所有线程退出
coord.join(threads)
综合TFRecord和队列——队列管理
tf.train.match_filenames_once函数:获取符合正则表达式的所有文件
tf.train.string_input_producer函数:对文件列表进行有效管理
注意:当一个输入队列中的所有文件都被处理完后,会将初始化时提供的文件列表中的文件全部重新加入队列。num_epochs参数限制加载初始文件列表的最大轮数。
import tensorflow as tf
files = tf.train.match_filenames_once("./data.tfrecords-*")
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# 读取并解析一个样本
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
features={
'i': tf.FixedLenFeature([], tf.int64),
'j': tf.FixedLenFeature([], tf.int64)
})
with tf.Session() as sess:
# 使用tf.train.match_filename_once()函数时需要初始化一些变量
tf.local_variables_initializer().run()
# 启动线程
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(6):
print(sess.run([features['i'], features['j']]))
# 停止所有线程
coord.request_stop()
coord.join(threads)
注意:TensorFlow提供了一套更高层的数据处理框架——“数据集”,tf.data核心组件。
并正式推荐使用数据集作为输入数据的首选框架。