本文将介绍:
- 构建一个tf.train.Example对象
- 将tf.train.Example对象存入文件中,生成tf_record文件
- 使用tf.data的API读取tf_record文件,并实现反序列化
- 将tf.train.Example对象存入压缩文件中,生成tf_record压缩文件
- 使用tf.data的API读取tf_record压缩文件
一,构建BytesList,FloatList,Int64List对象
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
# Print the version information of Python and of the libraries in use.
print(tf.__version__)
print(sys.version_info)
libraries = (mpl, np, pd, sklearn, tf, keras)
for module in libraries:
    print(module.__name__, module.__version__)
# tfrecord is a file format.
#   -> a tf.train.Example can be one sample or a group of samples
#   -> each Example -> tf.train.Features -> {"key": tf.train.Feature}
#   -> each tf.train.Feature is built from one of the list types
#      tf.train.BytesList / tf.train.FloatList / tf.train.Int64List

# Step 1: build a BytesList, a FloatList and an Int64List object.

# tf.train.BytesList: the strings are UTF-8 encoded to bytes first.
book_titles = ["machine learning", "cc150"]
favorite_books = [title.encode('utf-8') for title in book_titles]
favorite_books_bytelist = tf.train.BytesList(value=favorite_books)
print(favorite_books_bytelist, type(favorite_books_bytelist))

# tf.train.FloatList
hours_floatlist = tf.train.FloatList(value=[15.5, 9.5, 7.0, 8.0])
print(hours_floatlist, type(hours_floatlist))

# tf.train.Int64List
age_int64list = tf.train.Int64List(value=[42])
print(age_int64list, type(age_int64list))
二,构建一个tf.train.Features对象
# tf.train.Features: map each feature name to a tf.train.Feature that wraps
# the matching list object built earlier in the script.
feature_map = {
    "favorite_books": tf.train.Feature(bytes_list=favorite_books_bytelist),
    "hours": tf.train.Feature(float_list=hours_floatlist),
    "age": tf.train.Feature(int64_list=age_int64list),
}
features = tf.train.Features(feature=feature_map)
print(features)
三,构建一个tf.train.Example对象
# tf.train.Example: build the Example from the Features object.
example = tf.train.Example(features=features)
print(example)

# Serialize the tf.train.Example; the resulting value is what gets written
# to the TFRecord files further down.
serialized_example = example.SerializeToString()
print(serialized_example)
四,将tf.train.Example对象存入文件中,生成tf_record文件
# Directory and file name that receive the TFRecord output.
output_dir = 'tfrecord_basic'
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it has no
# check-then-create race and also works if output_dir is later changed to a
# nested path.
os.makedirs(output_dir, exist_ok=True)
filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
# Store the serialized tf.train.Example in the file; the same record is
# written three times.
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for _ in range(3):
        writer.write(serialized_example)
五,使用tf.data的API读取tf_record文件
# Read the TFRecord file back with tf.data and print each element as-is
# (no parsing is done in this step).
record_files = [filename_fullpath]
dataset = tf.data.TFRecordDataset(record_files)
for serialized_example_tensor in dataset:
    print(serialized_example_tensor)
六,将tf.train.Example对象反序列化
# Declare the data type of every feature to be parsed.
#   VarLenFeature   -> variable-length data
#   FixedLenFeature -> fixed-length data (declared here with shape [])
expected_features = dict(
    favorite_books=tf.io.VarLenFeature(dtype=tf.string),
    hours=tf.io.VarLenFeature(dtype=tf.float32),
    age=tf.io.FixedLenFeature([], dtype=tf.int64),
)
# Parse the TFRecord dataset: each element is parsed against
# expected_features, then every favorite book is printed as text.
dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    example = tf.io.parse_single_example(
        serialized=serialized_example_tensor,
        features=expected_features,
    )
    # "favorite_books" is converted to a dense tensor, with b"" as the
    # fill value, before iterating over it.
    books = tf.sparse.to_dense(example["favorite_books"], default_value=b"")
    for book in books:
        # Each element holds bytes; decode them back to a string.
        print(book.numpy().decode("UTF-8"))
七,将tf.train.Example对象存入压缩文件中,生成tf_record压缩文件
# Write the same three records to a compressed TFRecord file.
# NOTE(review): the path ends in '.zip' but the compression type requested
# below is GZIP; the suffix is kept unchanged here — consider '.gz'.
filename_fullpath_zip = f"{filename_fullpath}.zip"
options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options=options) as writer:
    for _ in range(3):
        writer.write(serialized_example)
八,使用tf.data的API读取tf_record压缩文件
# Read the compressed TFRecord file; the compression type passed here is the
# same "GZIP" value the file was written with.
dataset_zip = tf.data.TFRecordDataset(
    [filename_fullpath_zip],
    compression_type="GZIP",
)
for serialized_example_tensor in dataset_zip:
    example = tf.io.parse_single_example(
        serialized=serialized_example_tensor,
        features=expected_features,
    )
    # Convert "favorite_books" to a dense tensor (fill value b"") and print
    # each entry decoded as UTF-8 text.
    books = tf.sparse.to_dense(example["favorite_books"], default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))