官网:
https://www.tensorflow.org/versions/r0.12/how_tos/reading_data/
有三种方式可以读取数据给TensorFlow
- 每一步都用python来提供数据(feed_dict)
- 从文件读取(很大的数据集不能一次载入内存,当然要从文件读取啦):an input pipeline reads the data from files at the beginning of a TensorFlow graph.
- 预加载(小数据)
读取文件
一般流程
1. 文件列表
2. (可选)文件名shuffle
3. (可选)epoch 限制
4. 文件名队列
5. 适合文件类型的Reader
6. 解码器
7. (可选)预处理
8. 示例队列
而在tf中,标准的文件格式就是tfrecords,是二进制编码的数据。
先来看看官网怎么说的:
tf.train.Example
定义数据格式:
The recommended format for TensorFlow is a TFRecords file containing
tf.train.Example
protocol buffers (which contain Features as a field).
tf.python_io.TFRecordWriter
写入tfrecords:
You write a little program that gets your data, stuffs it in an Example protocol buffer, serializes the protocol buffer to a string, and then writes the string to a TFRecords file using the
tf.python_io.TFRecordWriter
class.
示例:convert_to_records.py
用tf.parse_single_example
解码,tf.TFRecordReader
读取
To read a file of TFRecords, use
tf.TFRecordReader
with the tf.parse_single_example
decoder.
示例:fully_connected_reader.py
data IO:
https://www.tensorflow.org/versions/r0.12/api_docs/python/python_io/#tfrecords_format_details
class tf.python_io.TFRecordWriter
tf.python_io.TFRecordWriter.__init__(path, options=None)
path是写入tfrecord的路径
tf.python_io.TFRecordWriter.write(record)
record: string
tf.python_io.TFRecordWriter.close()
结束写操作
tf.python_io.TFRecordWriter.__enter__()
Enter a with block.
tf.python_io.TFRecordWriter.__exit__(unused_type, unused_value, unused_traceback)
tf.python_io.tf_record_iterator(path, options=None)
class tf.python_io.TFRecordCompressionType
class tf.python_io.TFRecordOptions
tf.python_io.TFRecordOptions.__init__(compression_type)
tf.python_io.TFRecordOptions.get_compression_type_string(cls, options)
write to TFRecords
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts MNIST data to TFRecords file format with Example protos."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets import mnist
FLAGS = None
def _int64_feature(value):
    """Wrap a single integer in a `tf.train.Feature`.

    Args:
        value: The integer to store; it is placed in a one-element list.

    Returns:
        A `tf.train.Feature` whose `int64_list` holds `[value]`.
    """
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)
def _bytes_feature(value):
    """Wrap a single byte string in a `tf.train.Feature`.

    Args:
        value: The byte string to store; it is placed in a one-element list.

    Returns:
        A `tf.train.Feature` whose `bytes_list` holds `[value]`.
    """
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)
def convert_to(data_set, name):
    """Write every example of a dataset to a TFRecords file.

    The output path is `<FLAGS.directory>/<name>.tfrecords`. Each example is
    stored as a serialized `tf.train.Example` with the features `height`,
    `width`, `depth`, `label` and `image_raw`.

    Args:
        data_set: Object exposing `images`, `labels` and `num_examples`.
            `images` is indexed along four axes (example, rows, cols, depth).
        name: Base name of the output file, without the `.tfrecords` suffix.

    Raises:
        ValueError: If `images.shape[0]` differs from `num_examples`.
    """
    images = data_set.images
    labels = data_set.labels
    num_examples = data_set.num_examples

    if images.shape[0] != num_examples:
        raise ValueError('Images size %d does not match label size %d.' %
                         (images.shape[0], num_examples))
    rows = images.shape[1]
    cols = images.shape[2]
    depth = images.shape[3]

    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    print('Writing', filename)
    # The writer is used as a context manager so the file is closed on exit.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index in range(num_examples):
            # tobytes() replaces the deprecated ndarray.tostring() alias
            # (removed in NumPy 2.0); the produced bytes are identical.
            image_raw = images[index].tobytes()
            # Describe the record layout as an Example protocol buffer.
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'height': _int64_feature(rows),
                        'width': _int64_feature(cols),
                        'depth': _int64_feature(depth),
                        'label': _int64_feature(int(labels[index])),
                        'image_raw': _bytes_feature(image_raw),
                    }))
            # Records are written in serialized (string) form.
            writer.write(example.SerializeToString())
def main(unused_argv):
    """Load the MNIST splits and write each one to its own TFRecords file.

    Args:
        unused_argv: Ignored; present because `tf.app.run` passes argv.
    """
    # Images are requested as uint8 and left un-flattened (reshape=False).
    data_sets = mnist.read_data_sets(
        FLAGS.directory,
        dtype=tf.uint8,
        reshape=False,
        validation_size=FLAGS.validation_size,
    )

    # The split attribute name doubles as the output file's base name.
    for split_name in ('train', 'validation', 'test'):
        convert_to(getattr(data_sets, split_name), split_name)
if __name__ == '__main__':
    # Script entry point: define the command-line flags, store the parsed
    # values in the module-level FLAGS, and hand control to tf.app.run.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--directory', type=str, default='/tmp/data',
        help='Directory to download data files and write the converted result')
    parser.add_argument(
        '--validation_size', type=int, default=5000,
        help=('Number of examples to separate from the training data for the validation\n'
              'set.'))

    # Arguments argparse does not recognise are forwarded to tf.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    forwarded_argv = [sys.argv[0]] + unparsed
    tf.app.run(main=main, argv=forwarded_argv)
read TFRecords
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train and Eval the MNIST network.
This version is like fully_connected_feed.py but uses data converted
to a TFRecords file containing tf.train.Example protocol buffers.
See:
https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files
for context.
YOU MUST run convert_to_records before running this (but you only need to
run it once).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os.path
import sys
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import mnist
# Basic model parameters as external flags.
FLAGS = None
# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'
def decode(serialized_example):
    """Parse one serialized `tf.train.Example` into an (image, label) pair.

    Args:
        serialized_example: Scalar string tensor holding a serialized Example
            with the features `image_raw` (string) and `label` (int64).

    Returns:
        A tuple `(image, label)`: `image` is a uint8 tensor whose static shape
        is set to `[mnist.IMAGE_PIXELS]`, and `label` is an int32 scalar.
    """
    features = tf.parse_single_example(
        serialized_example,
        # Defaults are not specified since both keys are required.
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })

    # Convert from a scalar string tensor (whose single string has
    # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
    # [mnist.IMAGE_PIXELS].
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    # Pass an explicit one-element shape; the former `(mnist.IMAGE_PIXELS)`
    # was just a parenthesised int, not the tuple it appeared to be.
    image.set_shape([mnist.IMAGE_PIXELS])

    # Convert label from the parsed int64 scalar to an int32 scalar.
    label = tf.cast(features['label'], tf.int32)

    return image, label
def augment(image, label):
    """Return the example unchanged; placeholder for optional distortions.

    An image could be reshaped to 28x28 and distorted here. No distortion is
    applied in this example, and the following step expects a flattened
    vector, so the inputs are passed straight through.

    Args:
        image: The image for one example.
        label: The label for the same example.

    Returns:
        The tuple `(image, label)`, exactly as received.
    """
    passthrough = (image, label)
    return passthrough
def normalize(image, label):
    """Rescale image values from [0, 255] to floats in [-0.5, 0.5].

    Args:
        image: Image tensor; it is cast to float32 before scaling.
        label: Label for the image; returned untouched.

    Returns:
        A tuple `(image, label)` with the rescaled float32 image.
    """
    as_float = tf.cast(image, tf.float32)
    scaled = as_float * (1. / 255)
    centered = scaled - 0.5
    return centered, label
def inputs(train, batch_size, num_epochs):
    """Build the input pipeline that yields batches of (images, labels).

    The records are read from `TRAIN_FILE` or `VALIDATION_FILE` inside
    `FLAGS.train_dir`, repeated `num_epochs` times, decoded, augmented,
    normalized, shuffled and finally batched.

    Args:
        train: True selects the training file, False the validation file.
        batch_size: Number of examples per returned batch.
        num_epochs: Number of passes over the data; 0 or None is turned into
            None, i.e. the dataset is repeated without a limit.

    Returns:
        A tuple (images, labels), where:
        * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
          in the range [-0.5, 0.5].
        * labels is an int32 tensor with shape [batch_size] with the true
          label, a number in the range [0, mnist.NUM_CLASSES).

        The tensors come from a one-shot iterator: it walks the dataset a
        single time and requires no explicit initialization.
    """
    if not num_epochs:
        num_epochs = None

    record_file = TRAIN_FILE if train else VALIDATION_FILE
    filename = os.path.join(FLAGS.train_dir, record_file)

    with tf.name_scope('input'):
        # TFRecordDataset yields the serialized records stored in the file;
        # a list of filenames would be accepted as well.
        # Each map() applies the given function to every element.
        # The argument to shuffle() is the size of its shuffle buffer.
        dataset = (
            tf.data.TFRecordDataset(filename)
            .repeat(num_epochs)
            .map(decode)
            .map(augment)
            .map(normalize)
            .shuffle(1000 + 3 * batch_size)
            .batch(batch_size)
        )
        iterator = dataset.make_one_shot_iterator()

    return iterator.get_next()
def run_training():
    """Train the MNIST model until the input pipeline is exhausted.

    The graph is assembled from `inputs`, `mnist.inference`, `mnist.loss` and
    `mnist.training`, all configured through `FLAGS`. Training steps run in a
    loop that ends when the session raises `tf.errors.OutOfRangeError`.
    """
    # Everything below is built into a fresh graph made the default one.
    with tf.Graph().as_default():
        # Batched input tensors for the training split.
        image_batch, label_batch = inputs(
            train=True,
            batch_size=FLAGS.batch_size,
            num_epochs=FLAGS.num_epochs,
        )

        # Forward pass, loss and the op that applies one training update.
        logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2)
        loss = mnist.loss(logits, label_batch)
        train_op = mnist.training(loss, FLAGS.learning_rate)

        # A single op that initializes both global and local variables.
        global_init = tf.global_variables_initializer()
        local_init = tf.local_variables_initializer()
        init_op = tf.group(global_init, local_init)

        with tf.Session() as sess:
            sess.run(init_op)

            step = 0
            try:
                # Loop until the iterator signals the end of the data.
                while True:
                    tick = time.time()

                    # sess.run returns one value per fetched op; the value
                    # for train_op is discarded and only the loss is kept.
                    # Further ops or variables can be inspected by adding
                    # them to the fetch list.
                    _, loss_value = sess.run([train_op, loss])

                    elapsed = time.time() - tick

                    # Report progress every 100 steps.
                    if step % 100 == 0:
                        report = (step, loss_value, elapsed)
                        print('Step %d: loss = %.2f (%.3f sec)' % report)
                    step += 1
            except tf.errors.OutOfRangeError:
                summary = (FLAGS.num_epochs, step)
                print('Done training for %d epochs, %d steps.' % summary)
def main(_):
    """Entry point handed to `tf.app.run`; the argv argument is ignored."""
    return run_training()
if __name__ == '__main__':
    # Script entry point: register the command-line flags, store the parsed
    # values in the module-level FLAGS, and hand control to tf.app.run.
    parser = argparse.ArgumentParser()

    # One row per flag: (name, type, default, help text).
    flag_specs = (
        ('--learning_rate', float, 0.01, 'Initial learning rate.'),
        ('--num_epochs', int, 2, 'Number of epochs to run trainer.'),
        ('--hidden1', int, 128, 'Number of units in hidden layer 1.'),
        ('--hidden2', int, 32, 'Number of units in hidden layer 2.'),
        ('--batch_size', int, 100, 'Batch size.'),
        ('--train_dir', str, '/tmp/data', 'Directory with the training data.'),
    )
    for flag_name, flag_type, flag_default, flag_help in flag_specs:
        parser.add_argument(
            flag_name, type=flag_type, default=flag_default, help=flag_help)

    # Arguments argparse does not recognise are forwarded to tf.app.run.
    FLAGS, unparsed = parser.parse_known_args()
    forwarded_argv = [sys.argv[0]] + unparsed
    tf.app.run(main=main, argv=forwarded_argv)
http://warmspringwinds.github.io/tensorflow/tf-slim/2016/12/21/tfrecords-guide/
http://blog.csdn.net/u012759136/article/details/52232266
http://blog.csdn.net/matrix_space/article/details/60962026
https://github.com/ferreirafabio/video2tfrecords