TFRecordDataset以字节的形式从文件中加载TFRecords,本身不进行任何分析或解码。分析和解码可以通过在“TFRecordDataset”之后,应用Dataset.map(fn),来完成。
1.存储
使用tf.io.TFRecordWriter(example_path)
进行创建写入对象file_writer
,然后通过file_writer.write(record_bytes)
进行存储.
存储的对象record_bytes
需要是经过.SerializeToString()
序列化成字符串的变量.
record_bytes是有比较固定的设置模式的:
feature_diy={
'data1_name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data1]))
}
record_bytes=tf.train.Example(features=tf.train.Features(feature={feature_diy}))
实际例子:
with tf.io.TFRecordWriter(example_path) as file_writer:
for _ in range(4):
x, y = np.random.random(), np.random.random()
feature_diy = {
"x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
"y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
}
record_bytes = tf.train.Example(features=tf.train.Features(feature=feature_diy)).SerializeToString()
file_writer.write(record_bytes)
2.解码
解码分为两步:
第一步通过train_dataset = tf.data.TFRecordDataset(example_path)
获取序列化成字符的数据,
第二步通过train_dataset = train_dataset.map(decode_function)
对序列化的字符进行解码
解码函数采用以下结构定义:
def decode_fn(record_bytes):
feature_dick = tf.io.parse_single_example(
# Data
record_bytes,
# Schema
feature_description
)
return feature_dick
解码后返回的是一个迭代器,可以通过for batch in feature_dick:
遍历每一个batch.每一个batch都是一个字典变量,字典的key是data_name,字典的value就是存储的值.
- 拓展:得到迭代器之后,还可以设置数据的batch和是否shuffle_buff:
data_decode = tf.data.TFRecordDataset([example_path]).map(decode_fn)
data_decode.shuffle(cfg.shuffle_buff).batch(cfg.batch_size, drop_remainder=True)
batch:
函数功能:
将此数据集的连续元素合并为批次。
参数解释:
batch_size: 表示此数据集的连续batch_size个元素将在单个batch中组合。
drop_rement:(可选)表示如果最后一批的数量少于,是否应丢弃该批`batch_size’元素;默认不删除较小的一批.一般最好选择True
shuffle:
函数功能:
随机洗牌此数据集的元素。
参数解释:
buffer_size:表示新数据集将从此数据集中采样的元素数。
seed: 随机种子
reshuffle_each_iteration: 默认为True,如果为true,则表示数据集应在每次迭代时进行伪随机重组
- 当设置batch之后,采用
enumerate
方式读取数据:
for batch_idx, (data1, data2, data3) in enumerate(train_data):
#按照batch的方式读取数据
实际例子:
# 定义Feature结构,告诉编码器每个Feature的类型是什么
feature_description = {
"x": tf.io.FixedLenFeature([], dtype=tf.float32),
"y": tf.io.FixedLenFeature([], dtype=tf.float32)
}
def decode_fn(record_bytes):
feature_dick = tf.io.parse_single_example(
# Data
record_bytes,
# Schema
feature_description
)
return feature_dick
data_decode = tf.data.TFRecordDataset([example_path]).map(decode_fn)
for batch in data_decode:
print("x = {x:.4f}, y = {y:.4f}".format(**batch))
3. 官方代码实现和改良
稍微拆分了官方代码,便于理解.
#!/usr/bin/env python
# coding: utf-8
# TFRecordDataset以字节的形式从文件中加载TFRecords,本身不进行任何分析或解码。
# 分析和解码可以通过在“TFRecordDataset”之后,应用Dataset.map(fn),来完成。
import tempfile
import os
import numpy as np
import tensorflow as tf
example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
np.random.seed(0)
with tf.io.TFRecordWriter(example_path) as file_writer:
for _ in range(2):
x, y = np.random.random(), np.random.random()
# 定义存储的特征的名称,结构和数值
feature_data = {
"x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
"y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),
}
# 得到数据结构体
record_bytes = tf.train.Example(features=tf.train.Features(feature=feature_data))
# 将结构体序列化成字符串
record_bytes=record_bytes.SerializeToString()
print(record_bytes)
# 存储到file_writer中,也就是example.tfrecords中
file_writer.write(record_bytes)
# 定义Feature结构,告诉编码器每个Feature的类型是什么
feature_description = {
"x": tf.io.FixedLenFeature([], dtype=tf.float32),
"y": tf.io.FixedLenFeature([], dtype=tf.float32)
}
def decode_fn(record_bytes):
feature_dick = tf.io.parse_single_example(
# Data
record_bytes,
# Schema
feature_description
)
return feature_dick
data_decode = tf.data.TFRecordDataset([example_path]).map(decode_fn)
for batch in data_decode:
print("x = {x:.4f}, y = {y:.4f}".format(**batch))