[tensorflow]关于使用TFRecordDataset的对数据序列化存储和解码详解(附带官方demo)

TFRecordDataset以字节的形式从文件中加载TFRecords,本身不进行任何分析或解码。分析和解码可以通过在“TFRecordDataset”之后,应用Dataset.map(fn),来完成。

1.存储

使用tf.io.TFRecordWriter(example_path)进行创建写入对象file_writer,然后通过file_writer.write(record_bytes)进行存储.
存储的对象record_bytes需要是经过.SerializeToString()序列化成字符串的变量.
record_bytes是有比较固定的设置模式的:

feature_diy={
	'data1_name': tf.train.Feature(bytes_list=tf.train.BytesList(value=[data1]))
}

record_bytes=tf.train.Example(features=tf.train.Features(feature={feature_diy}))

实际例子:

with tf.io.TFRecordWriter(example_path) as file_writer:
    for _ in range(4):
        x, y = np.random.random(), np.random.random()
		feature_diy = {
			"x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
			"y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),  
		}
        record_bytes = tf.train.Example(features=tf.train.Features(feature=feature_diy)).SerializeToString()
        file_writer.write(record_bytes)

2.解码

解码分为两步:
第一步通过train_dataset = tf.data.TFRecordDataset(example_path)获取序列化成字符的数据,
第二步通过train_dataset = train_dataset.map(decode_function)对序列化的字符进行解码
解码函数采用以下结构定义:

def decode_fn(record_bytes):
    feature_dick = tf.io.parse_single_example(
        # Data
        record_bytes,
        # Schema
        feature_description
    )
    return feature_dick

解码后返回的是一个迭代器,可以通过for batch in feature_dick:遍历每一个batch.每一个batch都是一个字典变量,字典的key是data_name,字典的value就是存储的值.

  • 拓展:得到迭代器之后,还可以设置数据的batch和是否shuffle_buff:
data_decode = tf.data.TFRecordDataset([example_path]).map(decode_fn)
data_decode.shuffle(cfg.shuffle_buff).batch(cfg.batch_size, drop_remainder=True)

batch:

函数功能:
将此数据集的连续元素合并为批次。

参数解释:
batch_size: 表示此数据集的连续batch_size个元素将在单个batch中组合。
drop_rement:(可选)表示如果最后一批的数量少于,是否应丢弃该批`batch_size’元素;默认不删除较小的一批.一般最好选择True

shuffle:

函数功能:
随机洗牌此数据集的元素。

参数解释:
buffer_size:表示新数据集将从此数据集中采样的元素数。
seed: 随机种子
reshuffle_each_iteration: 默认为True,如果为true,则表示数据集应在每次迭代时进行伪随机重组

  • 当设置batch之后,采用enumerate方式读取数据:
for batch_idx, (data1, data2, data3) in enumerate(train_data):
	#按照batch的方式读取数据

实际例子:


# 定义Feature结构,告诉编码器每个Feature的类型是什么
feature_description = {
    "x": tf.io.FixedLenFeature([], dtype=tf.float32),
    "y": tf.io.FixedLenFeature([], dtype=tf.float32)
}

def decode_fn(record_bytes):
    feature_dick = tf.io.parse_single_example(
        # Data
        record_bytes,
        # Schema
        feature_description
    )
    return feature_dick

data_decode = tf.data.TFRecordDataset([example_path]).map(decode_fn)
for batch in data_decode:
    print("x = {x:.4f},  y = {y:.4f}".format(**batch))

3. 官方代码实现和改良

稍微拆分了官方代码,便于理解.

#!/usr/bin/env python
# coding: utf-8

# TFRecordDataset以字节的形式从文件中加载TFRecords,本身不进行任何分析或解码。
# 分析和解码可以通过在“TFRecordDataset”之后,应用Dataset.map(fn),来完成。

import tempfile
import os
import numpy as np
import tensorflow as tf
example_path = os.path.join(tempfile.gettempdir(), "example.tfrecords")
np.random.seed(0)

with tf.io.TFRecordWriter(example_path) as file_writer:
    for _ in range(2):
        x, y = np.random.random(), np.random.random()
        # 定义存储的特征的名称,结构和数值
        feature_data = {
            "x": tf.train.Feature(float_list=tf.train.FloatList(value=[x])),
            "y": tf.train.Feature(float_list=tf.train.FloatList(value=[y])),  
        }
        # 得到数据结构体
        record_bytes = tf.train.Example(features=tf.train.Features(feature=feature_data)) 
        # 将结构体序列化成字符串
        record_bytes=record_bytes.SerializeToString() 
        print(record_bytes)
#         存储到file_writer中,也就是example.tfrecords中
        file_writer.write(record_bytes)

# 定义Feature结构,告诉编码器每个Feature的类型是什么
feature_description = {
    "x": tf.io.FixedLenFeature([], dtype=tf.float32),
    "y": tf.io.FixedLenFeature([], dtype=tf.float32)
}

def decode_fn(record_bytes):
    feature_dick = tf.io.parse_single_example(
        # Data
        record_bytes,
        # Schema
        feature_description
    )
    return feature_dick

data_decode = tf.data.TFRecordDataset([example_path]).map(decode_fn)
for batch in data_decode:
    print("x = {x:.4f},  y = {y:.4f}".format(**batch))


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值