4-6 tf.record基础API使用

如何在tf.data中使用tensorflow特有的文件格式tf.record,我们现在没有tf.record文件,所以需要进行一个转化,把现有数据集转成tf.record,然后再用tf.data对tf.record进读取。再集成到keras模型中的训练中去。

首先看一下基本api的使用。

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras

print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

首先导入包。与之前一样。

tfrecord是一个文件格式,里面存储的内容都是tf.train.Example,Example可以是一个样本,也可以是一组样本。对于每个Example,其里面都是一个个的feature,tf.train.Features->{“key”:tf.train.Feature},tf.train.Feature有不同的数据格式,字符串、浮点等。

# tfrecord 文件格式
# -> tf.train.Example
#    -> tf.train.Features -> {"key": tf.train.Feature}
#       -> tf.train.Feature -> tf.train.ByteList/FloatList/Int64List

favorite_books = [name.encode('utf-8')
                  for name in ["machine learning", "cc150"]]
favorite_books_bytelist = tf.train.BytesList(value = favorite_books)
print(favorite_books_bytelist)

hours_floatlist = tf.train.FloatList(value = [15.5, 9.5, 7.0, 8.0])
print(hours_floatlist)

age_int64list = tf.train.Int64List(value = [42])
print(age_int64list)

features = tf.train.Features(
    feature = {
        "favorite_books": tf.train.Feature(
            bytes_list = favorite_books_bytelist),
        "hours": tf.train.Feature(
            float_list = hours_floatlist),
        "age": tf.train.Feature(int64_list = age_int64list),
    }
)
print(features)
example = tf.train.Example(features=features)
print(example)

serialized_example = example.SerializeToString()
print(serialized_example)

example与之前的feature差不多,但是序列化之后的内容就有点看不懂了。

以上就是一个example的具体定义,接下来看如何把example存到文件中去,生成一个具体的tfrecord文件。

output_dir = 'tfrecord_basic'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for i in range(3):
        writer.write(serialized_example)

首先定义输出目录,再定义文件名和全路径。因为存储的是tfrecord文件,所以也用相应的方法来打开。打开之后把刚才定义好 的序列化example写进去三次。运行之后去文件夹之下看就会发现文件夹已经创建好了,文件也生成好了。

接下来用tfrecordAPI读取文件。与之前的方法类似。

dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    print(serialized_example_tensor)

调用tf.data.TFRecordDataset方法,参数是一个列表,列表里面再放入全路径。遍历输出字符串之后发现与之前定义的是一样的。

那如何把序列化之后的example解析成肉眼可见的正常的example呢?

之前解析csv的时候是定义了一个列表,表明了每个field的类型。与之类似,这里定义一个字典,字典中定义了每个feature对应的类型。

expected_features = {
    "favorite_books": tf.io.VarLenFeature(dtype = tf.string),
    "hours": tf.io.VarLenFeature(dtype = tf.float32),
    "age": tf.io.FixedLenFeature([], dtype = tf.int64),
}
dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features)
    books = tf.sparse.to_dense(example["favorite_books"],
                               default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))

VarLenFeature:是指变长的feature;
FixedLenFeature:是指定长的feature,
指定类型之后不再直接打印,而是用tf.io.parse_single_example解析一下,传进去两个参数,一个是序列化后具体的tensor,另一个是我们期待的类型。
再读取其中的一个feature

还可以生成tdrecord的压缩文件

filename_fullpath_zip = filename_fullpath + '.zip'
options = tf.io.TFRecordOptions(compression_type = "GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:
    for i in range(3):
        writer.write(serialized_example)

在路径后面加上zip,要想得到压缩格式,就要定义options,打开文件的时候再把option传进去。

打开文件夹可以发现压缩后的文件更加小,那么如何读取压缩后的文件呢?

dataset_zip = tf.data.TFRecordDataset([filename_fullpath_zip], 
                                      compression_type= "GZIP")
for serialized_example_tensor in dataset_zip:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features)
    books = tf.sparse.to_dense(example["favorite_books"],
                               default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))

读取的方法就是在创建dataset的时候把多传一个参数:compression_type= “GZIP”,其他地方不变.
下一节看看如何在正常的代码中使用tfrecord。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值