本例代码:https://github.com/yongyehuang/Tensorflow-Tutorial/tree/master/python/the_use_of_tfrecord
本文介绍 tfrecord 的使用方法,分别说明如何用 tfrecord 处理以下三种不同类型的数据:
- 维度固定的 numpy 矩阵
- 可变长度的 序列 数据
- 图片数据
在 tf1.3 及以后的版本中,TensorFlow 推出了新的 Dataset API。之前忙于赶实验,还没来得及研究它,以后可能不太会再用下面这种方式来写了。这些代码都是之前写好的,注释已经写得比较清楚,所以这里直接上代码。
tfrecord_2_sequence_writer.py
# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
from tqdm import tqdm
'''tfrecord 写入序列数据,每个样本的长度不固定。
和固定 shape 的数据处理方式类似,前者使用 tf.train.Example() 方式,而对于变长序列数据,需要使用
tf.train.SequenceExample()。 在 tf.train.SequenceExample() 中,又包括了两部分:
context 来放置非序列化部分;
feature_lists 放置变长序列。
refer:
https://github.com/tensorflow/magenta/blob/master/magenta/common/sequence_example_lib.py
https://github.com/dennybritz/tf-rnn
http://leix.me/2017/01/09/tensorflow-practical-guides/
https://github.com/siavash9000/im2txt_demo/blob/master/im2txt/im2txt/ops/inputs.py
'''
# Script body: writes nine toy variable-length samples as SequenceExamples,
# the first half to seq_test1.tfrecord and the rest to seq_test2.tfrecord.

# **1. Create one TFRecord writer per output file.
writer1 = tf.python_io.TFRecordWriter('../../data/seq_test1.tfrecord')
writer2 = tf.python_io.TFRecordWriter('../../data/seq_test2.tfrecord')
# Non-sequence data: one label per sample.
labels = [1, 2, 3, 4, 5, 1, 2, 3, 4]
# Variable-length sequences: one list of frames per sample.
frames = [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],
          [1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]
# Floor division keeps the split index an int on Python 2 and 3 alike; with
# true division (9 / 2 == 4.5) the switch to writer2 would never trigger.
split_index = len(labels) // 2
try:
    writer = writer1
    for i, label in enumerate(tqdm(labels)):  # **2. For every sample
        if i == split_index:
            writer = writer2
            print('\nThere are %d sample writen into writer1' % i)
        frame = frames[i]
        # **3. Build the SequenceExample for this sample.
        # Non-sequence part: a single int64 feature holding the label.
        label_feature = tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label]))
        # Sequence part: one int64 feature per time step.
        frame_feature = [
            tf.train.Feature(int64_list=tf.train.Int64List(value=[frame_]))
            for frame_ in frame
        ]
        seq_example = tf.train.SequenceExample(
            # `context` holds the non-sequence part.
            context=tf.train.Features(feature={
                "label": label_feature
            }),
            # `feature_lists` holds the variable-length sequence.
            feature_lists=tf.train.FeatureLists(feature_list={
                "frame": tf.train.FeatureList(feature=frame_feature),
            })
        )
        serialized = seq_example.SerializeToString()
        writer.write(serialized)  # **4. Write the serialized sample to file
    print('Finished.')
finally:
    # Close both writers even if building or writing a sample fails.
    writer1.close()
    writer2.close()
tfrecord_2_sequence_reader.py
# -*- coding:utf-8 -*-
import tensorflow as tf
import math
# Queue capacity used by get_padded_batch for both the RandomShuffleQueue
# and the tf.train.batch queue.
QUEUE_CAPACITY = 100
# Cap passed as `stop_at` to count_records when get_padded_batch derives the
# shuffle queue's `min_after_dequeue`.
SHUFFLE_MIN_AFTER_DEQUEUE = QUEUE_CAPACITY // 5
"""
读取变长序列数据。
和固定shape的数据读取方式不一样,在读取变长序列中,我们无法使用 tf.train.shuffle_batch() 函数,只能使用
tf.train.batch() 函数进行读取,而且,在读取的时候,必须设置 dynamic_pad 参数为 True, 把所有的序列 padding
到固定长度(该batch中最长的序列长度),padding部分为 0。
此外,在训练的时候为了实现 shuffle 功能,我们可以使用 RandomShuffleQueue 队列来完成。详见下面的 _shuffle_inputs 函数。
"""
def _shuffle_inputs(input_tensors, capacity, min_after_dequeue, num_threads):
    """Route `input_tensors` through a RandomShuffleQueue as one grouped element.

    Args:
      input_tensors: List of tensors that together form one example.
      capacity: Capacity of the shuffle queue.
      min_after_dequeue: `min_after_dequeue` value given to the shuffle queue.
      num_threads: Number of times the enqueue op is repeated in the
          QueueRunner registered for the queue.

    Returns:
      The result of dequeuing one element from the shuffle queue, with each
      component's static shape copied from the matching input tensor.
    """
    component_dtypes = [tensor.dtype for tensor in input_tensors]
    queue = tf.RandomShuffleQueue(
        capacity, min_after_dequeue, dtypes=component_dtypes)
    enqueue_ops = [queue.enqueue(input_tensors)] * num_threads
    tf.train.add_queue_runner(tf.train.QueueRunner(queue, enqueue_ops))
    shuffled = queue.dequeue()
    # The queue is built without component shapes, so copy each input's static
    # shape onto the corresponding dequeued tensor.
    for position, source in enumerate(input_tensors):
        shuffled[position].set_shape(source.shape)
    return shuffled
def get_padded_batch(file_list, batch_size, num_enqueuing_threads=4, shuffle=False):
    """Reads batches of SequenceExamples from TFRecords and pads them.

    Each record is parsed into a scalar int64 context feature "label" and a
    variable-length int64 sequence feature "frame". Batching uses
    `tf.train.batch` with `dynamic_pad=True`, so sequences of different length
    can share a batch.

    Args:
      file_list: A list of paths to TFRecord files containing SequenceExamples.
      batch_size: The number of SequenceExamples to include in each batch.
      num_enqueuing_threads: The number of threads to use for enqueuing
          SequenceExamples. When shuffling, they are split between the shuffle
          queue and the batch queue.
      shuffle: Whether to pass examples through a random-shuffle queue before
          batching.

    Returns:
      The `[labels, frames]` list produced by `tf.train.batch`:
      labels: A tensor of shape [batch_size] of int64s.
      frames: A tensor of shape [batch_size, num_steps] of int64s, where
          num_steps is the longest sequence length in the batch.

    Raises:
      ValueError: If `shuffle` is True and `num_enqueuing_threads` is less than 2.
    """
    file_queue = tf.train.string_input_producer(file_list)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)

    context_features = {
        "label": tf.FixedLenFeature([], dtype=tf.int64)
    }
    sequence_features = {
        "frame": tf.FixedLenSequenceFeature([], dtype=tf.int64)
    }
    context_parsed, sequence_parsed = tf.parse_single_sequence_example(
        serialized=serialized_example,
        context_features=context_features,
        sequence_features=sequence_features
    )

    labels = context_parsed['label']
    frames = sequence_parsed['frame']
    input_tensors = [labels, frames]

    if shuffle:
        if num_enqueuing_threads < 2:
            raise ValueError(
                '`num_enqueuing_threads` must be at least 2 when shuffling.')
        # Give the shuffle queue ceil(n / 2) threads. The division must happen
        # inside ceil(); ceil() of the integer thread count itself is a no-op.
        # For n >= 2 this always leaves at least one thread for batching.
        shuffle_threads = int(math.ceil(num_enqueuing_threads / 2.))

        # Since there may be fewer records than SHUFFLE_MIN_AFTER_DEQUEUE, take the
        # minimum of that number and the number of records.
        min_after_dequeue = count_records(
            file_list, stop_at=SHUFFLE_MIN_AFTER_DEQUEUE)
        input_tensors = _shuffle_inputs(
            input_tensors, capacity=QUEUE_CAPACITY,
            min_after_dequeue=min_after_dequeue,
            num_threads=shuffle_threads)

        # The remaining threads feed the batch queue.
        num_enqueuing_threads -= shuffle_threads

    tf.logging.info(input_tensors)
    return tf.train.batch(
        input_tensors,
        batch_size=batch_size,
        capacity=QUEUE_CAPACITY,
        num_threads=num_enqueuing_threads,
        dynamic_pad=True,
        allow_smaller_final_batch=False)
def count_records(file_list, stop_at=None):
    """Count TFRecord records across `file_list`, optionally stopping early.

    Args:
      file_list: TFRecord file paths, scanned in order.
      stop_at: Optional cap. When truthy, counting stops as soon as the running
          total reaches it; a falsy value means every record is counted.

    Returns:
      The number of records seen: the full total, or the running total at the
      moment the `stop_at` cap was reached.
    """
    seen = 0
    for path in file_list:
        tf.logging.info('Counting records in %s.', path)
        # Numbering continues from the running total, so `seen` is the overall
        # count after each record and stays unchanged for an empty file.
        for seen, _ in enumerate(tf.python_io.tf_record_iterator(path), seen + 1):
            if stop_at and seen >= stop_at:
                tf.logging.info('Number of records is at least %d.', seen)
                return seen
    tf.logging.info('Total records: %d', seen)
    return seen
if __name__ == '__main__':
    # Demo: read three shuffled, padded batches from the two sample files and
    # print them.
    tfrecord_file_names = ['../../data/seq_test1.tfrecord', '../../data/seq_test2.tfrecord']
    label_batch, frame_batch = get_padded_batch(tfrecord_file_names, 10, shuffle=True)
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    # A coordinator lets the queue-runner threads be stopped and joined
    # before the session is closed.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # `range` works on both Python 2 and 3 (`xrange` is Python 2 only).
        for i in range(3):
            _frames_batch, _label_batch = sess.run([frame_batch, label_batch])
            print('** batch %d' % i)
            print(_label_batch)
            print(_frames_batch)
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()
sequence_example_lib.py
# -*- coding:utf-8 -*-
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for working with tf.train.SequenceExamples.
https://github.com/tensorflow/magenta/blob/master/magenta/common/sequence_example_lib.py
"""
import math
import tensorflow as tf
# Queue capacity used by get_padded_batch for both the RandomShuffleQueue
# and the tf.train.batch queue.
QUEUE_CAPACITY = 500
# Cap passed as `stop_at` to count_records when get_padded_batch derives the
# shuffle queue's `min_after_dequeue`.
SHUFFLE_MIN_AFTER_DEQUEUE = QUEUE_CAPACITY // 5
def make_sequence_example(inputs, labels):
    """Build a SequenceExample holding `inputs` and `labels` as feature lists.

    Args:
      inputs: A list of input vectors. Each vector is stored as the values of
          one float-list feature.
      labels: A list of ints. Each label is stored as a one-element int64-list
          feature.

    Returns:
      A tf.train.SequenceExample whose feature lists are keyed 'inputs' and
      'labels'.
    """
    inputs_list = tf.train.FeatureList(feature=[
        tf.train.Feature(float_list=tf.train.FloatList(value=vector))
        for vector in inputs])
    labels_list = tf.train.FeatureList(feature=[
        tf.train.Feature(int64_list=tf.train.Int64List(value=[step_label]))
        for step_label in labels])
    return tf.train.SequenceExample(
        feature_lists=tf.train.FeatureLists(feature_list={
            'inputs': inputs_list,
            'labels': labels_list,
        }))
def _shuffle_inputs(input_tensors, capacity, min_after_dequeue, num_threads):
    """Route `input_tensors` through a RandomShuffleQueue as one grouped element.

    Args:
      input_tensors: List of tensors that together form one example.
      capacity: Capacity of the shuffle queue.
      min_after_dequeue: `min_after_dequeue` value given to the shuffle queue.
      num_threads: Number of times the enqueue op is repeated in the
          QueueRunner registered for the queue.

    Returns:
      The result of dequeuing one element from the shuffle queue, with each
      component's static shape copied from the matching input tensor.
    """
    component_dtypes = [tensor.dtype for tensor in input_tensors]
    queue = tf.RandomShuffleQueue(
        capacity, min_after_dequeue, dtypes=component_dtypes)
    enqueue_ops = [queue.enqueue(input_tensors)] * num_threads
    tf.train.add_queue_runner(tf.train.QueueRunner(queue, enqueue_ops))
    shuffled = queue.dequeue()
    # The queue is built without component shapes, so copy each input's static
    # shape onto the corresponding dequeued tensor.
    for position, source in enumerate(input_tensors):
        shuffled[position].set_shape(source.shape)
    return shuffled
def get_padded_batch(file_list, batch_size, input_size,
                     num_enqueuing_threads=4, shuffle=False):
    """Reads batches of SequenceExamples from TFRecords and pads them.

    Each record is parsed into the sequence features 'inputs' (float32 vectors
    of size `input_size`) and 'labels' (int64 scalars). Batching uses
    `tf.train.batch` with `dynamic_pad=True`, so sequences of different length
    can share a batch.

    Args:
      file_list: A list of paths to TFRecord files containing SequenceExamples.
      batch_size: The number of SequenceExamples to include in each batch.
      input_size: The size of each input vector. The returned batch of inputs
          will have a shape [batch_size, num_steps, input_size].
      num_enqueuing_threads: The number of threads to use for enqueuing
          SequenceExamples. When shuffling, they are split between the shuffle
          queue and the batch queue.
      shuffle: Whether to pass examples through a random-shuffle queue before
          batching.

    Returns:
      The `[inputs, labels, lengths]` list produced by `tf.train.batch`:
      inputs: A tensor of shape [batch_size, num_steps, input_size] of floats32s.
      labels: A tensor of shape [batch_size, num_steps] of int64s.
      lengths: A tensor of shape [batch_size] of int32s. The lengths of each
          SequenceExample before padding.

    Raises:
      ValueError: If `shuffle` is True and `num_enqueuing_threads` is less than 2.
    """
    file_queue = tf.train.string_input_producer(file_list)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)

    sequence_features = {
        'inputs': tf.FixedLenSequenceFeature(shape=[input_size],
                                             dtype=tf.float32),
        'labels': tf.FixedLenSequenceFeature(shape=[],
                                             dtype=tf.int64)}

    _, sequence = tf.parse_single_sequence_example(
        serialized_example, sequence_features=sequence_features)

    # Sequence length: the size of the first dimension of the parsed inputs.
    length = tf.shape(sequence['inputs'])[0]
    input_tensors = [sequence['inputs'], sequence['labels'], length]

    if shuffle:
        if num_enqueuing_threads < 2:
            raise ValueError(
                '`num_enqueuing_threads` must be at least 2 when shuffling.')
        # Give the shuffle queue ceil(n / 2) threads. The division must happen
        # inside ceil(); ceil() of the integer thread count itself is a no-op.
        # For n >= 2 this always leaves at least one thread for batching.
        shuffle_threads = int(math.ceil(num_enqueuing_threads / 2.))

        # Since there may be fewer records than SHUFFLE_MIN_AFTER_DEQUEUE, take the
        # minimum of that number and the number of records.
        min_after_dequeue = count_records(
            file_list, stop_at=SHUFFLE_MIN_AFTER_DEQUEUE)
        input_tensors = _shuffle_inputs(
            input_tensors, capacity=QUEUE_CAPACITY,
            min_after_dequeue=min_after_dequeue,
            num_threads=shuffle_threads)

        # The remaining threads feed the batch queue.
        num_enqueuing_threads -= shuffle_threads

    tf.logging.info(input_tensors)
    return tf.train.batch(
        input_tensors,
        batch_size=batch_size,
        capacity=QUEUE_CAPACITY,
        num_threads=num_enqueuing_threads,
        dynamic_pad=True,
        allow_smaller_final_batch=False)
def count_records(file_list, stop_at=None):
    """Count TFRecord records across `file_list`, optionally stopping early.

    Args:
      file_list: TFRecord file paths, scanned in order.
      stop_at: Optional cap. When truthy, counting stops as soon as the running
          total reaches it; a falsy value means every record is counted.

    Returns:
      The number of records seen: the full total, or the running total at the
      moment the `stop_at` cap was reached.
    """
    seen = 0
    for path in file_list:
        tf.logging.info('Counting records in %s.', path)
        # Numbering continues from the running total, so `seen` is the overall
        # count after each record and stays unchanged for an empty file.
        for seen, _ in enumerate(tf.python_io.tf_record_iterator(path), seen + 1):
            if stop_at and seen >= stop_at:
                tf.logging.info('Number of records is at least %d.', seen)
                return seen
    tf.logging.info('Total records: %d', seen)
    return seen
def flatten_maybe_padded_sequences(maybe_padded_sequences, lengths=None):
    """Merge the batch and time dimensions, skipping padded steps if any.

    Args:
      maybe_padded_sequences: A tensor of possibly padded sequences to flatten,
          sized `[N, M, ...]` where M = max(lengths).
      lengths: Optional length of each sequence, sized `[N]`. If None, the
          sequences are treated as unpadded.

    Returns:
      The flattened sequence tensor, sized `[sum(lengths), ...]`.
    """
    def _collapse_leading_dims():
        # Every sequence has the same length, so merging the first two
        # dimensions is a plain reshape.
        trailing_dims = maybe_padded_sequences.shape.as_list()[2:]
        return tf.reshape(maybe_padded_sequences, [-1] + trailing_dims)

    if lengths is None:
        return _collapse_leading_dims()

    def _gather_valid_steps():
        # Keep only the positions that the sequence mask marks as valid.
        valid_positions = tf.where(tf.sequence_mask(lengths))
        return tf.gather_nd(maybe_padded_sequences, valid_positions)

    # If even the shortest sequence spans the whole time dimension, nothing is
    # padded and the reshape branch is taken.
    no_padding = tf.equal(
        tf.reduce_min(lengths), tf.shape(maybe_padded_sequences)[1])
    return tf.cond(no_padding, _collapse_leading_dims, _gather_valid_steps)