TFRecord基础API的使用
首先介绍一下tfrecord的文件组成格式:
tfrecord由example组成,example由features组成,而features由feature组成。
#tfrecord 文件格式
# --> tf.train.Example : 文件由Example组成
# --> tf.train.Features -> { feature{"key" : tf.train.Feature(),value = {...}},
# feature{...},
# ... } : Example由Features组成
# --> tf.train.Feature -->(格式可选)tf.train.BytesList/FloatList/Int64List : Features由Feature组成
栗子:
features {
feature {
key: "age"
value {
int64_list {
value: 42
}
}
}
feature {
key: "favorite_books"
value {
bytes_list {
value: "computer vision "
value: "machine learing"
}
}
}
feature {
key: "house"
value {
float_list {
value: 15.5
value: 9.5
value: 7.0
value: 8.0
}
}
}
}
使用到的函数:
定义数据类型:
tf.train.BytesList(value = repeated bytes value);
tf.train.FloatList(value = repeated float value);
tf.train.Int64List(value = repeated int64 value).
定义Feature:
tf.train.Feature([op])
op(以下三者只能选其一):
- bytes_list: BytesList bytes_list;
- float_list: FloatList float_list;
- int64_list: Int64List int64_list.
定义Features:
tf.train.Features(feature: repeated FeatureEntry feature)
写tfrecord文件:
tf.io.TFRecordWriter( path, options=None )
Args:
- path : 写入文件的路径;
- options=None: 指定压缩类型的字符串,TFRecordCompressionType, or TFRecordOptions对象。
压缩操作:
tf.io.TFRecordOptions( compression_type=None, )
Args:
compression_type=None:压缩格式,“GZIP”, “ZLIB”, or “” (no compression)。
读取tfrecord文件,生成数据集(Dataset):
tf.data.TFRecordDataset(filenames, compression_type=None, buffer_size=None, num_parallel_reads=None)
解析单个序列化的example为数据原型:
tf.io.parse_single_example(serialized, features, example_names=None, name=None)
Args:
- serialized : 标量字符串张量,一个序列化的example;
- features : 一个字典,将Feature key映射到FixedLenFeature或VarLenFeature。
代码示例:
import tensorflow as tf
import os
#tfrecord 文件格式
# --> tf.train.Example : 文件由Example组成
# --> tf.train.Features -> { feature{"key" : tf.train.Feature(),value = {...}},
# feature{...},
# ... } : Example由Features组成
# --> tf.train.Feature -->(格式可选)tf.train.BytesList/FloatList/Int64List : Features由Feature组成
# 1. Build the individual Feature payloads.
# Each tf.train.*List wraps a repeated field. BytesList requires bytes, so the
# book titles are UTF-8 encoded first.
favorite_books = [title.encode('utf-8')
                  for title in ("computer vision ", "machine learing")]
favorite_books_bytelist = tf.train.BytesList(value=favorite_books)
hours_floatlist = tf.train.FloatList(value=[15.5, 9.5, 7.0, 8.0])
age_int64list = tf.train.Int64List(value=[42])

print("favorite_books_bytelist : \n", favorite_books_bytelist)
print("hours_floatlist : \n", hours_floatlist)
print("age_int64list : \n", age_int64list)

# 2. Build Features: a mapping {"key": tf.train.Feature}.
# Each Feature sets exactly one of bytes_list / float_list / int64_list.
# NOTE: the float values are stored under the key "house" (not "hours"),
# so any parser must request "house" to get them back.
features = tf.train.Features(feature={
    "favorite_books": tf.train.Feature(bytes_list=favorite_books_bytelist),
    "house": tf.train.Feature(float_list=hours_floatlist),
    "age": tf.train.Feature(int64_list=age_int64list),
})
print("features : \n", features)
favorite_books_bytelist :
value: "computer vision "
value: "machine learing"
hours_floatlist :
value: 15.5
value: 9.5
value: 7.0
value: 8.0
age_int64list :
value: 42
features :
feature {
key: "age"
value {
int64_list {
value: 42
}
}
}
feature {
key: "favorite_books"
value {
bytes_list {
value: "computer vision "
value: "machine learing"
}
}
}
feature {
key: "house"
value {
float_list {
value: 15.5
value: 9.5
value: 7.0
value: 8.0
}
}
}
# 3. Wrap the Features in an Example, the record type stored in a TFRecord file.
example = tf.train.Example(features=features)
print(example)

# Serialization turns the in-memory object into a byte string that can be
# stored or transmitted; deserialization (parsing) turns it back into an object.
serialized_example = example.SerializeToString()
print(serialized_example)
features {
feature {
key: "age"
value {
int64_list {
value: 42
}
}
}
feature {
key: "favorite_books"
value {
bytes_list {
value: "computer vision "
value: "machine learing"
}
}
}
feature {
key: "house"
value {
float_list {
value: 15.5
value: 9.5
value: 7.0
value: 8.0
}
}
}
}
b'\nf\n7\n\x0efavorite_books\x12%\n#\n\x10computer vision \n\x0fmachine learing\n\x1d\n\x05house\x12\x14\x12\x12\n\x10\x00\x00xA\x00\x00\x18A\x00\x00\xe0@\x00\x00\x00A\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*'
# Store the serialized Example in a TFRecord file (three copies of the same record).
output_dir = "tfrecord_basic"
# makedirs(exist_ok=True) replaces the exists()+mkdir() pair: no
# check-then-create race, and missing parent directories are created too.
os.makedirs(output_dir, exist_ok=True)
filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
# The writer is used as a context manager so the file is flushed and closed.
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for _ in range(3):
        writer.write(serialized_example)
# Read the TFRecord file back. Each element is a scalar tf.string tensor
# holding one serialized Example (see the printed output: shape=(), dtype=string).
dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    print(serialized_example_tensor)
tf.Tensor(b'\nf\n7\n\x0efavorite_books\x12%\n#\n\x10computer vision \n\x0fmachine learing\n\x1d\n\x05house\x12\x14\x12\x12\n\x10\x00\x00xA\x00\x00\x18A\x00\x00\xe0@\x00\x00\x00A\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*', shape=(), dtype=string)
tf.Tensor(b'\nf\n7\n\x0efavorite_books\x12%\n#\n\x10computer vision \n\x0fmachine learing\n\x1d\n\x05house\x12\x14\x12\x12\n\x10\x00\x00xA\x00\x00\x18A\x00\x00\xe0@\x00\x00\x00A\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*', shape=(), dtype=string)
tf.Tensor(b'\nf\n7\n\x0efavorite_books\x12%\n#\n\x10computer vision \n\x0fmachine learing\n\x1d\n\x05house\x12\x14\x12\x12\n\x10\x00\x00xA\x00\x00\x18A\x00\x00\xe0@\x00\x00\x00A\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*', shape=(), dtype=string)
# Deserialize (parse) the records.
# Declare the expected schema: VarLenFeature for variable-length features
# (parsed into a SparseTensor), FixedLenFeature for fixed-length features
# (parsed into a dense Tensor).
# The keys must match the keys used when the Example was written: the float
# values were stored under "house", so requesting "hours" would silently
# yield an empty SparseTensor instead of the data.
expected_features = {
    "favorite_books": tf.io.VarLenFeature(dtype=tf.string),
    "house": tf.io.VarLenFeature(dtype=tf.float32),
    "age": tf.io.FixedLenFeature([], dtype=tf.int64),
}
dataset = tf.data.TFRecordDataset([filename_fullpath])
# Parsing is analogous to decoding CSV with tf.io.decode_csv:
# one serialized record in, a dict of tensors keyed by feature name out.
for serialized_example_tensor in dataset:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features,
    )
    print(example)
{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000202D35E1C48>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000202D44F4408>, 'age': <tf.Tensor: id=406, shape=(), dtype=int64, numpy=42>}
{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000202D35D9048>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000202D35CD748>, 'age': <tf.Tensor: id=415, shape=(), dtype=int64, numpy=42>}
{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000202D395D808>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000202D6EC3408>, 'age': <tf.Tensor: id=424, shape=(), dtype=int64, numpy=42>}
# Convert the SparseTensor produced by VarLenFeature into a regular dense tensor.
for serialized_example_tensor in dataset:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features,
    )
    # An explicit bytes default is needed for a string SparseTensor on TF
    # versions whose to_dense default_value is the integer 0.
    books = tf.sparse.to_dense(example["favorite_books"], default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))
computer vision
machine learing
computer vision
machine learing
computer vision
machine learing
# Store the records as a compressed TFRecord file using tf.io.TFRecordOptions.
# The payload is GZIP (not ZIP), so use the conventional ".gz" suffix; a
# ".zip" name would mislead tools that choose a decoder by extension.
filename_fullpath_zip = filename_fullpath + ".gz"
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:
    for _ in range(3):
        writer.write(serialized_example)
# Read the compressed file; compression_type must match the one used when writing.
dataset_zip = tf.data.TFRecordDataset([filename_fullpath_zip], compression_type="GZIP")
# Parse each record (same procedure as for the uncompressed file) and
# densify the variable-length string feature before decoding the titles.
for serialized_example_tensor in dataset_zip:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features,
    )
    books = tf.sparse.to_dense(example["favorite_books"], default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))
computer vision
machine learing
computer vision
machine learing
computer vision
machine learing