背景
现在做深度学习,大部分用户都使用的 PyTorch
. 但是因为历史原因,以及在线推理的速度,还有一些业务使用的是 TensorFlow
.
TensorFlow
的数据格式是 tfrecord
,这个格式不太好用,这里试图记录 在不清楚数据中各个特征类型的情况写去解析这份数据的过程。
代码
下面这份代码会 解析 .gz
后缀的一份 tfrecord
数据,并返回 3 个 字典,包含 这份数据的第 1 条内容,包含各个特征的数值、长度、类型。
import tensorflow as tf
# 如果是 tensorflow 1 , 需要打开 eager 模式
tf.enable_eager_execution()
def get_infos_from_feature_dict(feature_dict):
"""
这个函数用于 解析 feature_dict 中的每一个 Feature.
每一个 Feature 均包含 bytes_list, float_list, int64_list. 需要解析.
只能解析固定长度的 Feature. 因为只解析 DataSet 中的 第 1 个.
Args:
feature_dict: key 是 string, value 是 tf.core.example.feature_pb2.Feature
"""
feature_values = dict()
feature_lens = dict()
feature_types = dict()
for feature, value in feature_dict.items():
bytes_list = value.bytes_list.value
float_list = value.float_list.value
int64_list = value.int64_list.value
feature_len = 0
feature_type = "unknown"
feature_value = None
for i, each_type_value in enumerate((bytes_list, float_list, int64_list)):
each_type_len = len(each_type_value)
if each_type_len > 0:
if feature_len or feature_type != "unknown" or feature_value is not None:
raise ValueError(f"{feature} has more than 1 type.")
feature_type = ("string", "float", "int64")[i]
feature_len = each_type_len
feature_value = each_type_value
feature_values.update({feature: feature_value})
feature_lens.update({feature: feature_len})
feature_types.update({feature: feature_type})
return feature_values, feature_lens, feature_types
def get_first_data_from_tfrecord(tf_record_path):
"""
解析 并返回 tf_record 的第 1 条数据。
Returns:
3 个 dict, key 为 特征名, value 分别为 特征的具体数值(list), 特征的长度, 特征的类型("string", "float" 或 "int64")
"""
# 将文件转换为 TFRecordDataset. 如果文件有压缩,需要填写 compression_type. ".gz" 后缀对应 "GZIP"
dataset = tf.data.TFRecordDataset(tf_record_path, compression_type="GZIP")
# 将 dataset 转为 迭代器 并 取第1个 数据.
eager_tensor = dataset.__iter__().__next__() # 类型是 tensorflow.python.framework.ops.EagerTensor
example_bytes = eager_tensor.numpy()
# 使用 example 解析 eager_tensor
example = tf.train.Example()
example.ParseFromString(example_bytes)
# 取出 feature 的内容
feature_dict = dict(example.features.feature)
feature_values, feature_lens, feature_types = get_infos_from_feature_dict(feature_dict)
return feature_values, feature_lens, feature_types
if __name__ == "__main__":
data_path = r"/my_tfrecord_data01.tfrecord.gz"
values, lens, types = get_first_data_from_tfrecord(train_data_path)
总结
具体过程已经写在代码的注释中。
主要有如下几点要注意:
- 如果使用的是
TensorFLow
1.xx 的本版,需要打开 eager 模式。 - 根据数据是否有压缩,需要在
tf.data.TFRecordDataset
函数中填写compression_type
参考官网链接点击此处 - 使用
tf.train.Example()
解析后,可以将example.features.feature
转换为 python 字典,它的 key 是 特征名,value 是Feature
,这里的Feature
是 TF 自己的一个 class,每个Feature
都包含bytes_list, float_list, int64_list
, 需要根据具体内容解析这个特征到底是什么类型的。参考函数get_infos_from_feature_dict
.
如果这篇博客有帮到你,欢迎点赞、收藏、关注。谢谢~