本博文参考TensorFlow实战Google深度学习框架(郑泽宇,顾思宇),仅用作学习
一、TFRecord输入数据格式
TFRecord是tensorflow中存储数据的统一格式。可以统一不同的原始数据格式,并更加有效地管理不同的属性。TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。是一种可将图像的数据和标签放在一起的二进制文件,能节省内存,在TensorFlow中快速读取存储。
tf.train.Example的定义如下:
message Example{
Features Features=1;
};
message Features{
map<string,Feature> feature=1;
};
message Feature{
oneof kind{
ByteList bytes_list=1;
FloatList float_list=2;
Int64List int64_list=3;
}
};
tf.train.Example中包含了一个从属性名称到取值的字典。属性名称为字符串,属性的取值可以为字符串(ByteList),实数列表(FloatList)或整数列表(Int64List)。
从文件中读取数据一般分为:把样本数据写入TFRecords二进制文件,再从队列中读取。
1、生成TFRecord文件
需要将数据填到tf.train.Example的协议缓存区(Protocol Buffer)中,将协议缓存区序列化为一个字符串,通过tf.python_io.TFRecordWriter写入TFRecord文件中。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
mnist=input_data.read_data_sets("/path/to/mnist/data",dtype=tf.unit8,one_hot=True)
#训练数据
image=mnist.train.images
#训练数据所对应的的正确答案,可以作为一个属性保存在TFRecorde中
labels=mnist.train.labels
#训练数据的图像分辨率,可以作为Example中的一个属性
pixels=image.shape[1]
num_examples=mnist.train.num_examples
#生成整数型的属性
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
#生成字符串型的属性
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value