[Tensorflow]关于TFRecord和tf.Example的使用

为高效读取数据,可将其序列化存储,TFRecord是常用的存储二进制序列数据的方法。博客介绍了TFRecord的数据类型、创建方法,还阐述了使用Tensorflow读写TFRecord文件的方式,并给出了读写实例,同时提供了相关参考链接。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

 为了高效地读取数据,可以将数据进行序列化存储,这样也便于网络流式读取数据。TFRecord是一种比较常用的存储二进制序列数据的方法,基于Google的Protocol buffers格式的数据。

tf.Example类是一种将数据表示为{"string": value}形式的meassage类型,Tensorflow经常使用tf.Example来写入、读取TFRecord数据

1. 关于tf.Example

1.1 tf.Example的数据类型

 一般来说,tf.Example都是{"string": tf.train.Feature}这样的键值映射形式。其中,tf.train.Feature类可以使用以下3种类型

  • tf.train.BytesList: 可以使用的类型包括 stringbyte

  • tf.train.FloatList: 可以使用的类型包括 floatdouble

  • tf.train.Int64List: 可以使用的类型包括 enum,bool, int32, uint32, int64以及uint64

 为了将常用的数据类型(标量或list),转化为tf.Example兼容的tf.train.Feature类型,通过使用以下几个接口函数:

# 这里括号中的value是一个标量
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
    
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

为了示例的简洁,这里只是使用了变量。如果想要对张量进行处理,常用的方法是:使用tf.serialize_tensor函数将张量转化为二进制字符,然后使用_bytes_feature()进行处理;读取的时候使用tf.parse_tensor对二进制字符转换为Tensor类型。

 下面举一个简单例子了解一下,经过tf.train.Feature转换之后的结果

print(_bytes_feature(b'test_string'))

## 输出为: 
## bytes_list {
   
##   value: "test_string"    
## }

 所有的proto meassages都可以通过.SerializeToString方法转换为二进制字符串

feature = _float_feature(np.exp(1))
feature.SerializeToString()

# 输出为:b'\x12\x06\n\x04T\xf8-@'

1.2 创建一个tf.Example数据

 无论是什么类型,基于已有的数据构造tf.Example数据的流程是相同的:

  • 对于一个观测值,需要转化为上面所说的tf.train.Feature兼容的3种类型之一;

  • 构造一个字典映射,键key是string型的feature名称,值value
    是第1步中转换得到的值;

  • 第2步中得到的映射会被转换为Features类型的数据

假设存在一个数据集,包含4个特征:1个bool型,1个int型,1个string型以及1个float型;假设数据集的数量为10000

n_obeservations = int(1e4)

# bool类型的特征
feature0 = np.random.choice([False, True], n_observations)

# int型特征
feature1 = np.random.randint(0, 5, n_observations)

# string型特征
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]

# float型特征
feature3 = np.random.randn(n_observations)

 定义一个将各种类型封装的序列化函数

def serialize(feature0, feature1, feature2, feature3):
    feature = {
   
        "feature0": _int64_feature(feature0),
        "feature1": _int64_feature(feature1),
        "fe
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值