一、输入流水线读取数据流程
1). 创建文件名列表
相关函数:tf.train.match_filenames_once
2). 创建文件名队列
相关函数:tf.train.string_input_producer
3). 创建Reader读取数据
tf.ReaderBase 、 tf.TFRecordReader 、 tf.TextLineReader 、 tf.WholeFileReader 、 tf.IdentityReader 、 tf.FixedLengthRecordReader …
4).创建decoder解码器转换格式
tf.decode_csv 、 tf.decode_raw 、 tf.image.decode_image …
5). 创建样例队列
相关函数:tf.train.shuffle_batch
二、常用Reader、decoder介绍
CSV文件读取
阅读器:tf.TextLineReader
解析器:tf.decode_csv
二进制文件读取
阅读器:tf.FixedLengthRecordReader
解析器:tf.decode_raw
图像文件读取
阅读器:tf.WholeFileReader
解析器:tf.image.decode_image, tf.image.decode_gif, tf.image.decode_jpeg, tf.image.decode_png
TFRecords文件读取
阅读器:tf.TFRecordReader
解析器:tf.parse_single_example
又或者使用slim提供的简便方法:slim.dataset.Dataset以及slim.dataset_data_provider.DatasetDataProvider方法,一般slim.dataset.Dataset作为函数返回,需要接收Reader和Decoder作为参数。
def get_split(record_file_name, num_sampels, size):
reader = tf.TFRecordReader
keys_to_features = {
"image/encoded": tf.FixedLenFeature((), tf.string, ''),
"image/format": tf.FixedLenFeature((), tf.string, 'jpeg'),
"image/height": tf.FixedLenFeature([], tf.int64, tf.zeros([], tf.int64)),
"image/width": tf.FixedLenFeature([], tf.int64, tf.zeros([], tf.int64)),
}
items_to_handlers = {
"image": slim.tfexample_decoder.Image(shape=[size, size, 3]),
"height": slim.tfexample_decoder.Tensor("image/height"),
"width": slim.tfexample_decoder.Tensor("image/width"),
}
decoder = slim.tfexample_decoder.TFExampleDecoder(
keys_to_features, items_to_handlers
)
return slim.dataset.Dataset(
data_sources=record_file_name,
reader=reader,
decoder=decoder,
items_to_descriptions={},
num_samples=num_sampels
)
def get_image(num_samples, resize, record_file="image.tfrecord", shuffle=False):
provider = slim.dataset_data_provider.DatasetDataProvider(
get_split(record_file, num_samples, resize), # slim.dataset.Dataset 做参数
shuffle=shuffle
)
[data_image] = provider.get(["image"]) # Provider通过TFR字段获取batch size数据
return data_image
三、以图片文件为例
filename_queue = tf.train.string_input_producer(filenames,
shuffle=shuffle, num_epochs=epochs)
reader = tf.WholeFileReader()
_, img_bytes = reader.read(filename_queue)
image = tf.image.decode_png(img_bytes, channels=3)
if png else tf.image.decode_jpeg(img_bytes, channels=3)
1.建立文件名队列
filename_queue = tf.train.string_input_producer(filenames)
2.阅读器初始化 & 单次读取规则设定
# 初始化阅读器
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
# 指定被阅读文件
result.key, value = reader.read(filename_queue)
3.单次读取数据decode
# Convert from a string to a vector of uint8 that is record_bytes long.
# read出来的是一个二进制的string,将它解码依照uint8格式解码
record_bytes = tf.decode_raw(value, tf.uint8)
…… ……
由于读取来的tensor不具有静态shape,需要使用tensor.set_shape()指定shape(或者在处理中显示的赋予shape如使用reshape等函数),否则无法建立图
read_input.label.set_shape([1])
4.输入入网络
将最后的规则tensor传入batch生成池节点中,输出的张量可以直接feed进网络
images_train, labels_train = cifar10_input.distorted_inputs(data_dir=data_dir,
batch_size=batch_size)
…… ……
image_batch, label_batch = sess.run([images_train, labels_train])
_, loss_value = sess.run(
[train_op, loss],
feed_dict={image_holder:image_batch, label_holder:label_batch})
5.初始化队列(相关的线程控制器组件添加也在这里)
# 启动数据增强队列
tf.train.start_queue_runners()
附上线程控制组件使用示意,
import tensorflow as tf
sess = tf.Session()
coord = tf.train.coordinator()
threads = tf.train.start_queue_runners(sess=sess,coord=coord)
# 训练过程
coord.request_stop()
coord.join(threads)