为了高效地读取数据,可以将数据进行序列化存储,这样也便于网络流式读取数据。TFRecord
是一种比较常用的存储二进制序列数据的方法,基于Google的Protocol buffers格式的数据。
tf.Example
类是一种将数据表示为{"string": value}
形式的meassage类型,Tensorflow经常使用tf.Example
来写入、读取TFRecord
数据
1. 关于tf.Example
1.1 tf.Example
的数据类型
一般来说,tf.Example
都是{"string": tf.train.Feature}
这样的键值映射形式。其中,tf.train.Feature
类可以使用以下3种类型
-
tf.train.BytesList
: 可以使用的类型包括string
和byte
-
tf.train.FloatList
: 可以使用的类型包括float
和double
-
tf.train.Int64List
: 可以使用的类型包括enum
,bool
,int32
,uint32
,int64
以及uint64
。
为了将常用的数据类型(标量或list),转化为tf.Example
兼容的tf.train.Feature
类型,通过使用以下几个接口函数:
# 这里括号中的value是一个标量
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _float_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
为了示例的简洁,这里只是使用了变量。如果想要对张量进行处理,常用的方法是:使用
tf.serialize_tensor
函数将张量转化为二进制字符,然后使用_bytes_feature()
进行处理;读取的时候使用tf.parse_tensor
对二进制字符转换为Tensor
类型。
下面举一个简单例子了解一下,经过tf.train.Feature
转换之后的结果
print(_bytes_feature(b'test_string'))
## 输出为:
## bytes_list {
## value: "test_string"
## }
所有的proto meassages
都可以通过.SerializeToString
方法转换为二进制字符串
feature = _float_feature(np.exp(1))
feature.SerializeToString()
# 输出为:b'\x12\x06\n\x04T\xf8-@'
1.2 创建一个tf.Example
数据
无论是什么类型,基于已有的数据构造tf.Example
数据的流程是相同的:
-
对于一个观测值,需要转化为上面所说的
tf.train.Feature
兼容的3种类型之一; -
构造一个字典映射,键
key
是string型的feature
名称,值value
是第1步中转换得到的值; -
第2步中得到的映射会被转换为
Features
类型的数据
假设存在一个数据集,包含4个特征:1个bool型,1个int型,1个string型以及1个float型;假设数据集的数量为
10000
:
n_obeservations = int(1e4)
# bool类型的特征
feature0 = np.random.choice([False, True], n_observations)
# int型特征
feature1 = np.random.randint(0, 5, n_observations)
# string型特征
strings = np.array([b'cat', b'dog', b'chicken', b'horse', b'goat'])
feature2 = strings[feature1]
# float型特征
feature3 = np.random.randn(n_observations)
定义一个将各种类型封装的序列化函数
def serialize(feature0, feature1, feature2, feature3):
feature = {
"feature0": _int64_feature(feature0),
"feature1": _int64_feature(feature1),
"fe