参考了这篇博客的内容,做了些增加修改
TFRecord 是Google官方推荐的一种数据格式,是Google专门为TensorFlow设计的一种数据格式。
实际上,TFRecord是一种二进制文件,其能更好的利用内存,其内部包含了多个tf.train.Example, 而Example是protocol buffer(protobuf) 数据标准 [3] [4] 的实现,在一个Example消息体中包含了一系列的tf.train.feature属性,而 每一个feature 是一个key-value的键值对,其中,key 是string类型,而value 的取值有三种:
bytes_list: 可以存储string 和byte两种数据类型。
float_list: 可以存储float(float32)与double(float64) 两种数据类型 。
int64_list: 可以存储:bool, enum, int32, uint32, int64, uint64 。
tf.Example 类就是一种将数据表示为{‘string’: value}形式的 message类型,TensorFlow经常使用 tf.Example 来写入,读取 TFRecord数据。
通常情况下,tf.Example中可以使用以下几种格式:
tf.train.BytesList: 可以使用的类型包括 string和byte
tf.train.FloatList: 可以使用的类型包括 float和double
tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64
TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature
tf.train.Feature(bytes_list=tf.train.BytesList(value=[feature.tostring()]))
tf.train.Feature(int64_list=tf.train.Int64List(value=list(feature.shape)))
tf.train.Feature(float_list=tf.train.FloatList(value=[label]))
将数据转化为TFRecord文件
Google官方推荐在对于中大数据集来说,先将数据集转化为TFRecord数据(.tfrecords), 这样可加快你在数据读取, 预处理中的速度。
将数据转化为 tfrecord 格式只需要三步, 下面以三个features:context,question, answer为例 :
writer = tf.python_io.TFRecordWriter(out_file_name) # 1. 定义 writer对象
for data in dataes:
context = dataes[0]
question = dataes[1]
answer = dataes[2]
""" 2. 定义features """
example = tf.train.Example(
features = tf.train.Features(
feature = {
'context': tf.train.Feature(
int64_list=tf.train.Int64List(value=context)),
'question': tf.train.Feature(
int64_list=tf.train.Int64List(value=question)),
'answer': tf.train.Feature(
int64_list=tf.train.Int64List(value=answer))
}))
""" 3. 序列化,写入"""
serialized = example.SerializeToString()
writer.write(serialized)
将一张图片转化成TFRecord 文件
import tensorflow as tf
def write_test(input, output):
# 借助于TFRecordWriter 才能将信息写入TFRecord 文件
writer = tf.python_io.TFRecordWriter(output)
# 读取图片并进行解码
image = tf.read_file(input)
image = tf.image.decode_jpeg(image)
with tf.Session() as sess:
image = sess.run(image)
shape = image.shape
# 将图片转换成string
image_data = image.tostring()
print(type(image))
print(len(image_data))
name = bytes('cat', encoding='utf-8')
print(type(name))
# 创建Example对象,并将Feature一一对应填充进去
example = tf.train.Example(features=tf.train.Features(feature={
'name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[name])),
'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=[shape[0], shape[1], shape[2]])),
'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data]))
}
))
# 将example序列化成string 类型,然后写入。
writer.write(example.SerializeToString())
writer.close()
if __name__ == '__main__':
input_photo = 'cat.jpg'
output_file = 'cat.tfrecord'
write_test(input_photo, output_file)
TFRecord 文件读取为图片
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def _parse_record(example_photo):
features = {
'name': tf.FixedLenFeature((), tf.string),
'shape': tf.FixedLenFeature([3], tf.int64),
'data': tf.FixedLenFeature((), tf.string)
}
parsed_features = tf.parse_single_example(example_photo,features=features)
return parsed_features
def read_test(input_file):
# 用dataset读取TFRecords文件
dataset = tf.data.TFRecordDataset(input_file)
dataset = dataset.map(_parse_record)
iterator = dataset.make_one_shot_iterator()
with tf.Session() as sess:
features = sess.run(iterator.get_next())
name = features['name']
name = name.decode()
img_data = features['data']
shape = features['shape']
print("==============")
print(type(shape))
print(len(img_data))
# 从bytes数组中加载图片原始数据,并重新reshape,它的结果是 ndarray 数组
img_data = np.fromstring(img_data, dtype=np.uint8)
image_data = np.reshape(img_data, shape)
plt.figure()
# 显示图片
plt.imshow(image_data)
plt.show()
# 将数据重新编码成jpg图片并保存
img = tf.image.encode_jpeg(image_data)
tf.gfile.GFile('cat_encode.jpg', 'wb').write(img.eval())
if __name__ == '__main__':
read_test("cat.tfrecord")
1,首先使用dataset去读取tfrecord文件
2,在解析example 的时候,用现成的API:tf.parse_single_example
3,用 np.fromstring() 方法就可以获取解析后的string数据,记得把数据还原成 np.uint8
4,用 tf.image.encode_jepg() 方法可以将图片数据编码成 jpeg 格式
5,用 tf.gfile.GFile 对象可以把图片数据保存到本地
6,因为将图片 shape 写入了example 中,所以解析的时候必须指定维度,在这里 [3],不然程序会报错。