在文章tfrecord格式的内容解析及样例 中我们已经分析了tfrecord 的内容是什么格式,接下来就要学习tfrecord怎么使用,及tfrecord的读写。
生成tfrecord
tfrecord文件的写入非常简单,仍然用tfrecord格式的内容解析及样例 中的例子,我们首先生成一个example
value_city = u"北京".encode('utf-8') # 城市
value_use_day = 7 #最近7天打开淘宝次数
value_pay = 289.4 # 最近7 天消费金额
value_poi = [b"123", b"456", b"789"] #最近7天浏览电铺
'''
下面生成ByteList,Int64List和FloatList
'''
bl_city = tf.train.BytesList(value = [value_city]) ## tf.train.ByteList入参是list,所以要转为list
il_use_day = tf.train.Int64List(value = [value_use_day])
fl_pay = tf.train.FloatList(value = [value_pay])
bl_poi = tf.train.BytesList(value = value_poi)
'''
下面生成tf.train.Feature
'''
feature_city = tf.train.Feature(bytes_list = bl_city)
feature_use_day = tf.train.Feature(int64_list = il_use_day)
feature_pay = tf.train.Feature(float_list = fl_pay)
feature_poi = tf.train.Feature(bytes_list = bl_poi)
'''
下面定义tf.train.Features
'''
feature_dict = {"city":feature_city,"use_day":feature_use_day,"pay":feature_pay,"poi":feature_poi}
features = tf.train.Features(feature = feature_dict)
'''
下面定义tf.train.example
'''
example = tf.train.Example(features = features)
然后就是把这个example写入文件中
path = "./tfrecord"
with tf.io.TFRecordWriter(path) as file_writer:
file_writer.write(example.SerializeToString())
至此,就完成了tfrecord文件的写入。
当然,到这里还没完,用tf.io写入example的字节和直接用Python的写入example的字节 是一样的吗? 为此我们做一个实验
path = "./tfrecord"
path2 = "./tfrecord2"
with tf.io.TFRecordWriter(path) as file_writer:
file_writer.write(example.SerializeToString())
with open(path2,"wb") as f:
f.write(example.SerializeToString())
通过上面的代码,我们分别通过tf.io和Python的open方法把example的字节写入2个文件。比较大小后发现一个是86字节,一个是99字节。看来内容还是不一样的,所以不能用Python自带的open方法代替tf.io
tfrecord读取
tfrecord的读取也很简单,但是tensorflow的官方document写的真的非常糟糕,以下全部是我个人摸索出来的结果。接上代码
path = "./tfrecord"
data = tf.data.TFRecordDataset(pathtensor)
以上实际上就已经完成了tfrecord的读取过程。很多人会说,可是无论平时使用还是工程中,都会用一个map方法对data进行变换呀。没错,如果使用需要进行变换,这是因为我们在保存tfrecord的时候,先把一个example序列化成二进制,然后再把二进制字节变成一个string,这样每个example就是一个string保存在了tfrecord 中。而读取过程同样,通过tf.data.TFRecordDataset,我们已经把每个example变成的string以 tf.tensor(dtype=string) 的方式读取进来了。所以我们完全可以用下面代码看读取结果
for batch in data:
print(batch)
result:
tf.Tensor(b'\nQ\n\x18\n\x03poi\x12\x11\n\x0f\n\x03123\n\x03456\n\x03789\n\x12\n\x04city\x12\n\n\x08\n\x06\xe5\x8c\x97\xe4\xba\xac\n\x10\n\x07use_day\x12\x05\x1a\x03\n\x01\x07\n\x0f\n\x03pay\x12\x08\x12\x06\n\x043\xb3\x90C', shape=(), dtype=string)
这里还有另外一个大坑,data是一个TFRecordDatasetV2类,但同时,它也是个可迭代对象,所以就算找遍它的所有属性和方法,都找不到它保存数据的tensor,但是可以通过迭代看到。
在Python中,可迭代对象是指有__iter__属性的对象,这类对象可以用循环取迭代,所以可以放在for中迭代,其他对象例如整型,float等不是可迭代对象,放在循环中会报错 “object is not iterable”。
当然只是把example序列化的字节,读取出来是不能用的,我们还是要把其中数据解析出来,这时候就要用到熟悉的map 方法了
def decode_fn(record_bytes):
return tf.io.parse_single_example(
record_bytes,
{
"city":tf.io.FixedLenFeature([],dtype = tf.string),
"use_day":tf.io.FixedLenFeature([],dtype = tf.int64),
"pay":tf.io.FixedLenFeature([],dtype = tf.float32)
,"poi":tf.io.VarLenFeature(dtype=tf.string)
})
data2 = data.map(decode_fn)
tf.io.parse_single_example 输入是一个string的tensor 输出是一个 dict ,格式就是如入参中的格式,应该注意的是,入参中的key应该去全部在example中出现过,否则会报错。
在弄懂了data的内容之后,我们就可以通过下面的方法调用decode_fn:
for batch in data:
print(decode_fn(batch))
result:
{'poi': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x000002532FEF7860>, 'city': <tf.Tensor: id=55, shape=(), dtype=string, numpy=b'\xe5\x8c\x97\xe4\xba\xac'>, 'pay': <tf.Tensor: id=56, shape=(), dtype=float32, numpy=289.4>, 'use_day': <tf.Tensor: id=58, shape=(), dtype=int64, numpy=7>}
可以看到2种读取方法内容是一样的。
在这里还有一个问题,在tf的官方教程中,io接口下有2个很类似的函数:tf.io.parse_single_example和tf.io.parse_example。这两个有什么区别呢?
1. 解析的example规模不同。
我们先来看官方的文档
tf.io.parse_example的官方文档如下
Args | |
---|---|
serialized | A vector (1-D Tensor) of strings, a batch of binary serialized Example protos. |
features | A dict mapping feature keys to FixedLenFeature , VarLenFeature , SparseFeature , and RaggedFeature values. |
example_names | A vector (1-D Tensor) of strings (optional), the names of the serialized protos in the batch. |
name | A name for this operation (optional). |
Returns | |
---|---|
A dict mapping feature keys to Tensor , SparseTensor , and RaggedTensor values. |
tf.io.parse_single_example官方文档如下
Args | |
---|---|
serialized | A scalar string Tensor, a single serialized Example. |
features | A dict mapping feature keys to FixedLenFeature or VarLenFeature values. |
example_names | (Optional) A scalar string Tensor, the associated name. |
name | A name for this operation (optional). |
Returns | |
---|---|
A dict mapping feature keys to Tensor and SparseTensor values. |
通过官方给的定义和函数的名字就可以看出来,tf.io.parse_single_example只对单条example的二进制序列进行解析,得到的也就是一个example,所以他的第一个入参要求是scalar string Tensor,即标量tensor,其实就是一个字符串。所以在上面的例子中
for batch in data:
print(batch)
result:
tf.Tensor(b'\nQ\n\x18\n\x03poi\x12\x11\n\x0f\n\x03123\n\x03456\n\x03789\n\x12\n\x04city\x12\n\n\x08\n\x06\xe5\x8c\x97\xe4\xba\xac\n\x10\n\x07use_day\x12\x05\x1a\x03\n\x01\x07\n\x0f\n\x03pay\x12\x08\x12\x06\n\x043\xb3\x90C', shape=(), dtype=string)
result看似是一个tensor,但它没有形状,所以说本质上还是一个标量(字符串),并非张量
tensorflow中有三个概念
标量(scalar tensor),也可以认为就是普通的变量,是0阶张量,shape一般是空
向量(vector),就是一阶张量
张量,不用解释,用的最多
那如果把标量变形成一个向量或者张量,这样的入参不符合parse_single_example的入参定义,就会报错
而tf.io.parse_example正好相反,tf.io.parse_example可以解析一批example,所以他的入参是一个向量,就算是只对一个example进行解析,也必须把标量变形成向量,也就是说应该写成
def decode_fn(record_bytes):
return tf.io.parse_example(
tf.reshape(record_bytes,[1]), #注意这一行发生了变化
{
"city":tf.io.FixedLenFeature([],dtype = tf.string),
"use_day":tf.io.FixedLenFeature([],dtype = tf.int64),
"pay":tf.io.FixedLenFeature([],dtype = tf.float32)
,"poi":tf.io.VarLenFeature(dtype=tf.string)
})
data2 = data.map(decode_fn)
这里应该注意,tf.io.parse_example的第一个入参只能是向量,绝对不能是二维以上的张量,否则同样报错。
2.对可变长sparse特征的解析结果不同
这个区别是非常有趣的,我们来看上面的poi这个特征,他是一个sparse特征,无论是通过tf.io.parse_example 还是tf.io.parse_single_example,我们都是把字符串解析了出来,得到了 ["123", "456", "789"]三个店铺id,但实际上一般都要对这类特征进行onehot,变成数值类型的输入。
用tf.io.parse_example得到的onrhot编码是一个向量例如,假设一共有5家店铺[a,"123", b, "456", "789"]。那么用tf.io.parse_example,在经过onehot会得到[0,1,0,1,1],而parse_single_example会得到
[[0,1,0,0,0]
[0,0,0,1,0]
0,0,0,0,1]]
这个会在https://blog.csdn.net/kangshuangzhu/article/details/106851826中详细介绍
结语
这里还有一个问题,在定义tf.io.parse_single_example的时候,我们需要给出返回的dict的形式。当特征数量较少的时候这当然没问题,但是工程中一般特征非常多,动辄上千维,用这种方法定义很明显是非常低效的。这时候tf.feature_column就是一个非常有用的工具了。tf.feature_column的内容下一篇文章再进行讲解