Keras(十五)tf_record基础API使用

38 篇文章 2 订阅
35 篇文章 11 订阅

本文将介绍:

  • 构建一个tf.train.Example对象
  • 将tf.train.Example对象存入文件中,生成的tf_record文件.
  • 使用tf.data的API读取tf_record文件,并实现反序列化
  • 将tf.train.Example对象存入压缩文件中,生成的tf_record压缩文件.
  • 使用tf.data的API读取tf_record压缩文件.

一,构建一个ByteList,FloatList,Int64List的对象

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras

# 打印使用的python库的版本信息
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)
    
# tfrecord 是一种文件格式
# -> tf.train.Example可以是一个样本或者一组
#    每个Example-> tf.train.Features -> {"key": tf.train.Feature}
#       每个Feature有不同的格式-> tf.train.Feature -> tf.train.ByteList/FloatList/Int64List

# 1,构建一个ByteList,FloatList,Int64List的对象
# tf.train.ByteList
favorite_books = [name.encode('utf-8') for name in ["machine learning", "cc150"]]
favorite_books_bytelist = tf.train.BytesList(value = favorite_books)
print(favorite_books_bytelist,type(favorite_books_bytelist))

# tf.train.FloatList
hours_floatlist = tf.train.FloatList(value = [15.5, 9.5, 7.0, 8.0])
print(hours_floatlist,type(hours_floatlist))

# tf.train.Int64List
age_int64list = tf.train.Int64List(value = [42])
print(age_int64list,type(age_int64list))

二,构建一个tf.train.Features对象

# tf.train.Features
features = tf.train.Features(
    feature = {
        "favorite_books": tf.train.Feature(bytes_list = favorite_books_bytelist),
        "hours": tf.train.Feature(float_list = hours_floatlist),
        "age": tf.train.Feature(int64_list = age_int64list),
    }
) 
print(features)

三,构建一个tf.train.Example对象

# tf.train.Example(使用features构建Example)
example = tf.train.Example(features=features)
print(example)

# 将tf.train.Example对象序列化.
serialized_example = example.SerializeToString()
print(serialized_example)

四,将tf.train.Example对象存入文件中,生成的tf_record文件

output_dir = 'tfrecord_basic'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
# 将序列化后的tf.train.Example对象存入文件中.
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for i in range(3):# 写入三次
        writer.write(serialized_example)

五,使用tf.data的API读取tf_record文件

dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    print(serialized_example_tensor)

六,将tf.train.Example对象反序列化

# 定义每个feature的数据类型
expected_features = {
    "favorite_books": tf.io.VarLenFeature(dtype = tf.string), # VarLenFeature代表变长的数据类型
    "hours": tf.io.VarLenFeature(dtype = tf.float32), 
    "age": tf.io.FixedLenFeature([], dtype = tf.int64), # FixedLenFeature代表定长的数据类型
}
# 解析TFRecord数据集
dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features)
    # print(example)
    books = tf.sparse.to_dense(example["favorite_books"],default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))

七,将tf.train.Example对象存入压缩文件中,生成的tf_record压缩文件

filename_fullpath_zip = filename_fullpath + '.zip'
options = tf.io.TFRecordOptions(compression_type = "GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:
    for i in range(3):
        writer.write(serialized_example)

八,使用tf.data的API读取tf_record压缩文件

dataset_zip = tf.data.TFRecordDataset([filename_fullpath_zip],compression_type= "GZIP")
for serialized_example_tensor in dataset_zip:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features)
    books = tf.sparse.to_dense(example["favorite_books"],default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值