概要
TFRecord 是一种用于存储二进制记录序列的简单格式。协议缓冲区(Protocol Buffers)是一个跨平台、跨语言的库,用于高效地序列化结构化数据。协议消息由 .proto 文件定义,阅读 .proto 文件通常是理解消息类型最简单的方法。
tf.train.Example 消息是一种灵活的消息类型,用于表示 {"string": value} 形式的映射。它专为 TensorFlow 设计,并在 TFX 等高级 API 中广泛使用。
本文将介绍如何创建、解析和使用 tf.train.Example 消息,以及如何将 tf.train.Example 消息序列化后写入 .tfrecord 文件,并从该文件中读取。
内容
TFRecord 文件包含一系列记录,且只能按顺序读取。TFRecord 文件并不要求必须使用 tf.train.Example;tf.train.Example 只是将字典序列化为字节串的一种方法。任何能在 TensorFlow 中解码的字节串都可以存储在 TFRecord 文件中,例如:文本行、JSON 数据(使用 tf.io.decode_json_example)、编码后的图像数据,以及序列化的 tf.Tensor(使用 tf.io.serialize_tensor / tf.io.parse_tensor)。
tf.data 模块也提供了在 TensorFlow 中读写数据的工具。
写一个TFRecord文件
将数据转换为数据集最简单的方法是使用 tf.data.Dataset.from_tensor_slices 方法。
使用 tf.data.Dataset.map 方法可以对数据集中的每个样本应用一个函数(类似于 pandas 的 apply)。
# The number of observations in the dataset.
n_observations = int(1e4)

# Boolean feature, encoded as False or True.
feature0 = np.random.choice([False, True], n_observations)

# Integer feature, random from 0 to 4 (randint's upper bound is exclusive).
feature1 = np.random.randint(0, 5, n_observations)

# String feature: each integer in feature1 indexes into `strings`, so
# feature2 is fully determined by feature1.
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# Float feature, from a standard normal distribution.
feature3 = np.random.randn(n_observations)

# Slicing a single array yields a dataset of scalar elements. Bind the
# result so this example has an effect outside an interactive notebook cell
# (a bare expression statement is discarded in a script).
feature1_dataset = tf.data.Dataset.from_tensor_slices(feature1)

# Slicing a tuple of arrays yields a dataset whose elements are tuples:
# one entry per array, all taken at the same index.
features_dataset = tf.data.Dataset.from_tensor_slices(
    (feature0, feature1, feature2, feature3))

# Inspect the first element of the tuple dataset.
for f0, f1, f2, f3 in features_dataset.take(1):
    print(f0)
    print(f1)
    print(f2)
    print(f3)
# Serialize every sample of the dataset with tf.data.Dataset.map.
# First, define the per-sample serialization function.
def serialize_example(feature0, feature1, feature2, feature3):
    """Build a tf.train.Example from four feature values and serialize it.

    Args:
        feature0: value stored under 'feature0' via _int64_feature.
        feature1: value stored under 'feature1' via _int64_feature.
        feature2: value stored under 'feature2' via _bytes_feature.
        feature3: value stored under 'feature3' via _float_feature.

    Returns:
        The result of SerializeToString() on the assembled tf.train.Example.
    """
    # Wrap each raw value in the tf.train.Feature kind its key expects.
    features = tf.train.Features(feature={
        'feature0': _int64_feature(feature0),
        'feature1': _int64_feature(feature1),
        'feature2': _bytes_feature(feature2),
        'feature3': _float_feature(feature3),
    })
    return tf.train.Example(features=features).SerializeToString()
# Wrap the Python serializer with tf.py_function so Dataset.map can accept it.
def tf_serialize_example(f0, f1, f2, f3):
    """Dataset.map-compatible wrapper around serialize_example.

    serialize_example is called through tf.py_function with the four
    features as its inputs, declared to return a tf.string, and the result
    is reshaped to a scalar (shape ()).
    """
    serialized = tf.py_function(
        func=serialize_example,
        inp=(f0, f1, f2, f3),
        Tout=tf.string,
    )
    # Force the py_function output to a scalar string.
    return tf.reshape(serialized, ())


# Serialize every element of features_dataset via map.
serialized_features_dataset = features_dataset.map(tf_serialize_example)
# Alternatively, build the serialized dataset with
# tf.data.Dataset.from_generator, which needs a generator function.
def generator():
    """Yield one serialized tf.train.Example per element of features_dataset."""
    yield from (serialize_example(*features) for features in features_dataset)


# Yields scalar tf.string elements, the same structure as the map-based
# version; this rebinds serialized_features_dataset.
serialized_features_dataset = tf.data.Dataset.from_generator(
    generator,
    output_types=tf.string,
    output_shapes=(),
)
# Destination path for the serialized examples.
filename = "test.tfrecord"
# Write every element of serialized_features_dataset into that single file.
# NOTE(review): tf.data.experimental.TFRecordWriter is an experimental API;
# confirm it is still available in the TensorFlow version in use.
writer = tf.data.experimental.TFRecordWriter(filename=filename)
writer.write(dataset=serialized_features_dataset)
读一个TFRecord文件
使用 tf.data.TFRecordDataset 可以读取 TFRecord 文件。
# Read the file back as a dataset of raw serialized records.
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames=filenames)

# Show the raw (still serialized) form of the first 10 records.
for raw_record in raw_dataset.take(10):
    print(f"{raw_record!r}")
# Parsing schema for each serialized tf.train.Example: every feature is a
# scalar (shape []) with an explicit default_value.
feature_description = {
    name: tf.io.FixedLenFeature([], dtype, default_value=default)
    for name, dtype, default in (
        ('feature0', tf.int64, 0),
        ('feature1', tf.int64, 0),
        ('feature2', tf.string, ''),
        ('feature3', tf.float32, 0.0),
    )
}


def _parse_function(example_proto):
    """Parse one serialized tf.train.Example using feature_description."""
    return tf.io.parse_single_example(example_proto, feature_description)


# Decode the raw records and show the first 10 parsed feature dicts.
parsed_dataset = raw_dataset.map(_parse_function)
for parsed_record in parsed_dataset.take(10):
    print(f"{parsed_record!r}")
使用 tf.io 模块对 TFRecord 文件进行读写
# Write the `tf.train.Example` observations to the file with
# tf.io.TFRecordWriter; the `with` block closes the writer on exit.
with tf.io.TFRecordWriter(filename) as writer:
    # Walk the four feature arrays in lockstep rather than indexing each one
    # by position with range(n_observations).
    for row in zip(feature0, feature1, feature2, feature3):
        example = serialize_example(*row)
        writer.write(example)
# Read the file back and decode a record with the protobuf API directly.
filenames = [filename]
raw_dataset = tf.data.TFRecordDataset(filenames)

for raw_record in raw_dataset.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw_record.numpy())
    print(example)

    # Convert the parsed Example into {feature name: NumPy array}. This stays
    # inside the loop so it only runs on an Example actually parsed here,
    # never on a stale `example` bound earlier in the script.
    result = {}
    # example.features.feature maps feature names to Feature messages.
    for key, feature in example.features.feature.items():
        # A Feature's `kind` oneof holds one of bytes_list, float_list or
        # int64_list; WhichOneof returns None when none of them is set.
        kind = feature.WhichOneof('kind')
        if kind is None:
            # Empty Feature: there is no value list to read.
            result[key] = np.array([])
            continue
        result[key] = np.array(getattr(feature, kind).value)
图像数据的TFRecord示例
# Download the sample images; get_file's return value is used below as the
# local file path.
cat_in_snow = tf.keras.utils.get_file(
    '320px-Felis_catus-cat_on_snow.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg')
williamsburg_bridge = tf.keras.utils.get_file(
    '194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/194px-New_East_River_Bridge_from_Brooklyn_det.4a09796u.jpg')

# Show the cat image and its CC-BY attribution link.
display.display(display.Image(filename=cat_in_snow))
# The href attribute must be written as href="URL" for the link to work.
display.display(display.HTML('Image cc-by: <a href="https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))
# Writing TFRecords: map each downloaded image path to its integer label.
image_labels = {
    cat_in_snow: 0,
    williamsburg_bridge: 1,
}

# This is an example, just using the cat image: read its raw bytes inside
# `with` so the file handle is closed promptly instead of leaking.
with open(cat_in_snow, 'rb') as image_file:
    image_string = image_file.read()
label = image_labels[cat_in_snow]
# Build a tf.train.Example holding the encoded image plus metadata features.
def image_example(image_string, label):
    """Return a tf.train.Example describing one JPEG image.

    The image is decoded with tf.io.decode_jpeg only to read its shape; the
    'image_raw' feature stores the original, undecoded bytes.

    Args:
        image_string: encoded image bytes.
        label: label value, stored via _int64_feature.

    Returns:
        A tf.train.Example with 'height', 'width', 'depth', 'label' and
        'image_raw' features.
    """
    height, width, depth = tf.io.decode_jpeg(image_string).shape
    feature = {
        'height': _int64_feature(height),
        'width': _int64_feature(width),
        'depth': _int64_feature(depth),
        'label': _int64_feature(label),
        'image_raw': _bytes_feature(image_string),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature))


# Preview the first 15 lines of the cat image's Example in text form.
print(*str(image_example(image_string, label)).split('\n')[:15], sep='\n')
print('...')
# Write the raw image files to `images.tfrecords`.
# First, process the two images into `tf.train.Example` messages.
# Then, write to a `.tfrecords` file.
record_file = 'images.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
    # NOTE: this loop rebinds the module-level names `filename` and `label`.
    for filename, label in image_labels.items():
        # Close each image file right after reading instead of leaking it.
        with open(filename, 'rb') as image_file:
            image_string = image_file.read()
        tf_example = image_example(image_string, label)
        writer.write(tf_example.SerializeToString())
# Reading TFRecords: stream the image records back from disk.
raw_image_dataset = tf.data.TFRecordDataset('images.tfrecords')

# Parsing schema: four scalar int64 metadata features followed by the raw
# encoded image bytes; no default_value is given for any of them.
image_feature_description = {
    key: tf.io.FixedLenFeature([], tf.int64)
    for key in ('height', 'width', 'depth', 'label')
}
image_feature_description['image_raw'] = tf.io.FixedLenFeature([], tf.string)


def _parse_image_function(example_proto):
    """Parse one serialized image Example with image_feature_description."""
    return tf.io.parse_single_example(example_proto, image_feature_description)


parsed_image_dataset = raw_image_dataset.map(_parse_image_function)
# Render every stored image from its raw encoded bytes.
for image_features in parsed_image_dataset:
    image_raw = image_features['image_raw'].numpy()
    display.display(display.Image(data=image_raw))