Hands-on: Generating and Reading TFRecord Files

1. Basic TFRecord API usage

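TFRecord file format

A TFRecord file is a sequence of binary records. Each record is usually a serialized tf.train.Example, which is nested as follows:

         -> tf.train.Example
                            -> tf.train.Features -> dict{"key": tf.train.Feature}
                                                -> tf.train.Feature -> tf.train.BytesList/FloatList/Int64List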

(1) Define tf.train.BytesList/FloatList/Int64List

import tensorflow as tf

# BytesList stores raw bytes, so each string is encoded as UTF-8 first
favorite_books = [name.encode('utf-8')
                  for name in ["machine learning", "cc150"]]
favorite_books_bytelist = tf.train.BytesList(value = favorite_books)
print(favorite_books_bytelist)

hours_floatlist = tf.train.FloatList(value = [15.5, 9.5, 7.0, 8.0])
print(hours_floatlist)

age_int64list = tf.train.Int64List(value = [42])
print(age_int64list)
value: "machine learning"
value: "cc150"

value: 15.5
value: 9.5
value: 7.0
value: 8.0

value: 42

(2) Define the Features

features = tf.train.Features(
    feature = {
        "favorite_books": tf.train.Feature(
            bytes_list = favorite_books_bytelist),
        "hours": tf.train.Feature(
            float_list = hours_floatlist),
        "age": tf.train.Feature(
            int64_list = age_int64list),
    }
)
print(features)
feature {
  key: "age"
  value {
    int64_list {
      value: 42
    }
  }
}
feature {
  key: "favorite_books"
  value {
    bytes_list {
      value: "machine learning"
      value: "cc150"
    }
  }
}
feature {
  key: "hours"
  value {
    float_list {
      value: 15.5
      value: 9.5
      value: 7.0
      value: 8.0
    }
  }
}

(3) Define the Example

example = tf.train.Example(features=features)
print(example)
features {
  feature {
    key: "age"
    value {
      int64_list {
        value: 42
      }
    }
  }
  feature {
    key: "favorite_books"
    value {
      bytes_list {
        value: "machine learning"
        value: "cc150"
      }
    }
  }
  feature {
    key: "hours"
    value {
      float_list {
        value: 15.5
        value: 9.5
        value: 7.0
        value: 8.0
      }
    }
  }
}
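Steps (1) to (3) above always follow the same nesting, so a common way to keep them readable is to wrap each list type in a small helper. The sketch below assumes TensorFlow 2.x; the helper names bytes_feature, float_feature and int64_feature are hypothetical and not part of the TensorFlow API.

```python
import tensorflow as tf


# Hypothetical helpers (not TensorFlow APIs): each wraps a list of Python
# values in the matching tf.train.Feature variant.
def bytes_feature(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))


def float_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))


def int64_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


# Example -> Features -> {key: Feature} -> BytesList / FloatList / Int64List
example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "favorite_books": bytes_feature([b"machine learning", b"cc150"]),
            "hours": float_feature([15.5, 9.5, 7.0, 8.0]),
            "age": int64_feature([42]),
        }
    )
)

# Reading a value walks back down the same hierarchy.
age = example.features.feature["age"].int64_list.value[0]
hours = list(example.features.feature["hours"].float_list.value)
print(age, hours)  # 42 [15.5, 9.5, 7.0, 8.0]
```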

(4) Serialize the Example

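Serialization converts the Example into a compact binary byte string (the protocol buffer wire format). This byte string is what gets written to the TFRecord file, and its compact encoding keeps the file small.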

serialized_example = example.SerializeToString()
print(serialized_example)
b'\n\\\n-\n\x0efavorite_books\x12\x1b\n\x19\n\x10machine learning\n\x05cc150\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10\x00\x00xA\x00\x00\x18A\x00\x00\xe0@\x00\x00\x00A'
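SerializeToString() has an inverse: protobuf message classes provide FromString() (or the instance method ParseFromString()), so the bytes can be turned back into a tf.train.Example in plain Python. A minimal round-trip check, assuming TensorFlow 2.x:

```python
import tensorflow as tf

example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "age": tf.train.Feature(int64_list=tf.train.Int64List(value=[42])),
            "hours": tf.train.Feature(
                float_list=tf.train.FloatList(value=[15.5, 9.5])),
        }
    )
)

serialized = example.SerializeToString()  # plain Python bytes
restored = tf.train.Example.FromString(serialized)  # parse the bytes back

print(type(serialized).__name__)  # bytes
print(restored == example)  # True
print(restored.features.feature["age"].int64_list.value[0])  # 42
```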

(5) Write a TFRecord file

Write the serialized Example to a file to produce a TFRecord file.

import os

output_dir = 'tfrecord_basic'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)  # full path: directory + file name
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for i in range(3):
        writer.write(serialized_example)  # write the same serialized example three times
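TFRecordWriter adds no schema to the file. Each write() appends one length-delimited record, laid out little-endian as: a uint64 payload length, a uint32 masked CRC-32C of those 8 length bytes, the payload itself, and a uint32 masked CRC-32C of the payload. The pure-Python sketch below reproduces that framing for illustration only. It uses a slow bitwise CRC-32C and is not meant to replace tf.io.TFRecordWriter; the helper names are hypothetical. The 94-byte placeholder payload has the same length as the serialized example printed in step (4).

```python
import struct


def crc32c(data: bytes) -> int:
    """Bitwise CRC-32C (Castagnoli, reflected polynomial 0x82F63B78)."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc ^= byte
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    rotated = ((crc >> 15) | (crc << 17)) & 0xFFFFFFFF
    return (rotated + 0xA282EAD8) & 0xFFFFFFFF


def frame_record(payload: bytes) -> bytes:
    """Hypothetical helper: one record = 16 bytes of framing + payload."""
    length = struct.pack("<Q", len(payload))
    return (length + struct.pack("<I", masked_crc(length))
            + payload + struct.pack("<I", masked_crc(payload)))


def read_records(blob: bytes):
    """Hypothetical helper: yield payloads back out of framed bytes."""
    offset = 0
    while offset < len(blob):
        length_bytes = blob[offset:offset + 8]
        (length,) = struct.unpack("<Q", length_bytes)
        (length_crc,) = struct.unpack("<I", blob[offset + 8:offset + 12])
        if length_crc != masked_crc(length_bytes):
            raise ValueError("corrupted length field")
        start = offset + 12
        payload = blob[start:start + length]
        (data_crc,) = struct.unpack("<I", blob[start + length:start + length + 4])
        if data_crc != masked_crc(payload):
            raise ValueError("corrupted payload")
        yield payload
        offset = start + length + 4


payload = b"x" * 94  # placeholder bytes, same length as the serialized example
blob = b"".join(frame_record(payload) for _ in range(3))
print(len(blob))  # 330 = 3 * (8 + 4 + 94 + 4)
```

The two checksums let a reader detect a truncated or corrupted record without knowing anything about the Example inside it.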

(6) Read the TFRecord file

1) Read the serialized Examples directly from the TFRecord file

dataset = tf.data.TFRecordDataset([filename_fullpath])  # build a TFRecordDataset from the file list
for serialized_example_tensor in dataset:
    print(serialized_example_tensor)
tf.Tensor(b'\n\\\n-\n\x0efavorite_books\x12\x1b\n\x19\n\x10machine learning\n\x05cc150\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10\x00\x00xA\x00\x00\x18A\x00\x00\xe0@\x00\x00\x00A', shape=(), dtype=string)
tf.Tensor(b'\n\\\n-\n\x0efavorite_books\x12\x1b\n\x19\n\x10machine learning\n\x05cc150\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10\x00\x00xA\x00\x00\x18A\x00\x00\xe0@\x00\x00\x00A', shape=(), dtype=string)
tf.Tensor(b'\n\\\n-\n\x0efavorite_books\x12\x1b\n\x19\n\x10machine learning\n\x05cc150\n\x0c\n\x03age\x12\x05\x1a\x03\n\x01*\n\x1d\n\x05hours\x12\x14\x12\x12\n\x10\x00\x00xA\x00\x00\x18A\x00\x00\xe0@\x00\x00\x00A', shape=(), dtype=string)

2) Parse the serialized Examples back into regular features, then read them

# A dict that declares the type of every feature to parse
expected_features = {
    "favorite_books": tf.io.VarLenFeature(dtype = tf.string),  # variable length
    "hours": tf.io.VarLenFeature(dtype = tf.float32),
    "age": tf.io.FixedLenFeature([], dtype = tf.int64),  # a single scalar
}
dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    # print(serialized_example_tensor)
    # instead of printing the raw bytes, parse each record first
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features)
    print(example)
    books = tf.sparse.to_dense(example["favorite_books"],
                               default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))
{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000125F39D5198>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000125F39D5278>, 'age': <tf.Tensor: id=164, shape=(), dtype=int64, numpy=42>}
machine learning
cc150
{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000125F3999E48>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000125F39992B0>, 'age': <tf.Tensor: id=183, shape=(), dtype=int64, numpy=42>}
machine learning
cc150
{'favorite_books': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000125F39D51D0>, 'hours': <tensorflow.python.framework.sparse_tensor.SparseTensor object at 0x00000125F39D52B0>, 'age': <tf.Tensor: id=202, shape=(), dtype=int64, numpy=42>}
machine learning
cc150
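VarLenFeature suits features whose number of values differs between examples, but it returns a tf.SparseTensor, which is why the loop above calls tf.sparse.to_dense. When every example is known to hold the same number of values, a FixedLenFeature with that shape returns a dense tensor directly (parsing fails if an example holds a different count). A small comparison, assuming TensorFlow 2.x and that hours always has four entries:

```python
import tensorflow as tf

serialized = tf.train.Example(
    features=tf.train.Features(
        feature={
            "hours": tf.train.Feature(
                float_list=tf.train.FloatList(value=[15.5, 9.5, 7.0, 8.0])),
        }
    )
).SerializeToString()

# VarLenFeature: any number of values, returned as a tf.SparseTensor.
sparse = tf.io.parse_single_example(
    serialized, {"hours": tf.io.VarLenFeature(tf.float32)})["hours"]
dense_from_sparse = tf.sparse.to_dense(sparse)

# FixedLenFeature([4]): exactly 4 values expected, returned as a dense tensor.
dense = tf.io.parse_single_example(
    serialized, {"hours": tf.io.FixedLenFeature([4], tf.float32)})["hours"]

print(isinstance(sparse, tf.SparseTensor))  # True
print(dense.numpy().tolist())  # [15.5, 9.5, 7.0, 8.0]
```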

(7) Store the TFRecord as a compressed file (uncompressed 330 B vs. compressed 127 B)

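A compressed file is written and read almost exactly like an uncompressed one, so only the differences are covered here: the writer is given tf.io.TFRecordOptions(compression_type = "GZIP"), and tf.data.TFRecordDataset is given compression_type = "GZIP" when reading. Everything not mentioned works the same as with the uncompressed file.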

filename_fullpath_zip = filename_fullpath + '.zip'  # the content is GZIP-compressed, not a ZIP archive
options = tf.io.TFRecordOptions(compression_type = "GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:
    for i in range(3):
        writer.write(serialized_example)
dataset_zip = tf.data.TFRecordDataset([filename_fullpath_zip], 
                                      compression_type= "GZIP")
for serialized_example_tensor in dataset_zip:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        expected_features)
    books = tf.sparse.to_dense(example["favorite_books"],
                               default_value=b"")
    for book in books:
        print(book.numpy().decode("UTF-8"))
machine learning
cc150
machine learning
cc150
machine learning
cc150
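The GZIP option compresses the whole file as one gzip stream around the ordinary record framing, so the three identical records here shrink well. If that is right, a GZIP TFRecord file should also be readable with Python's standard gzip module. The sketch below checks this assumption by writing the same three records with and without compression into a temporary directory (the file names are arbitrary), assuming TensorFlow 2.x:

```python
import gzip
import os
import tempfile

import tensorflow as tf

example = tf.train.Example(
    features=tf.train.Features(
        feature={
            "favorite_books": tf.train.Feature(bytes_list=tf.train.BytesList(
                value=[b"machine learning", b"cc150"])),
            "hours": tf.train.Feature(float_list=tf.train.FloatList(
                value=[15.5, 9.5, 7.0, 8.0])),
            "age": tf.train.Feature(int64_list=tf.train.Int64List(value=[42])),
        }
    )
)
serialized = example.SerializeToString()

tmp_dir = tempfile.mkdtemp()
plain_path = os.path.join(tmp_dir, "test.tfrecords")
gzip_path = os.path.join(tmp_dir, "test.tfrecords.gz")

with tf.io.TFRecordWriter(plain_path) as writer:
    for _ in range(3):
        writer.write(serialized)

options = tf.io.TFRecordOptions(compression_type="GZIP")
with tf.io.TFRecordWriter(gzip_path, options) as writer:
    for _ in range(3):
        writer.write(serialized)

with open(plain_path, "rb") as f:
    plain_bytes = f.read()
with gzip.open(gzip_path, "rb") as f:  # standard-library gzip reader
    unzipped_bytes = f.read()

print(len(plain_bytes), os.path.getsize(gzip_path))  # compressed file is smaller
print(unzipped_bytes == plain_bytes)  # expected: True
```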

2. Hands-on: converting CSV files to TFRecord files

In csv_reader_dataset (see the full code at the end), each line is parsed inside the map function.

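In parse_csv_line, each line is split into the first eight fields and the last field, i.e. x and y. As a result, every batch in the dataset returned by csv_reader_dataset consists of two elements: x_batch and y_batch.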
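To make the split concrete, the sketch below runs a copy of parse_csv_line from the full code on one made-up CSV line with nine fields (the values are arbitrary). An empty field falls back to the np.nan default. It assumes TensorFlow 2.x and NumPy.

```python
import numpy as np
import tensorflow as tf


# Same logic as parse_csv_line in the full code: 8 feature columns + 1 label.
def parse_csv_line(line, n_fields=9):
    defs = [tf.constant(np.nan)] * n_fields  # default used for empty fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])  # first 8 fields -> shape (8,)
    y = tf.stack(parsed_fields[-1:])  # last field -> shape (1,)
    return x, y


x, y = parse_csv_line(tf.constant("1,2,3,4,5,,7,8,1.5"))
print(x.shape, y.shape)  # (8,) (1,)
print(y.numpy().tolist())  # [1.5]
print(np.isnan(x.numpy()[5]))  # True: the empty sixth field became NaN
```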

To undo the batching, we simply use a for loop, as follows:

# Next, iterate over the three datasets and write the data into TFRecord files
def serialize_example(x, y):
    """Convert x, y to a tf.train.Example and serialize it."""
    input_features = tf.train.FloatList(value = x)
    label = tf.train.FloatList(value = y)
    features = tf.train.Features(
        feature = {
            "input_features": tf.train.Feature(
                float_list = input_features),
            "label": tf.train.Feature(float_list = label)
        }
    )
    example = tf.train.Example(features = features)
    return example.SerializeToString()


# Convert the dataset read from CSV into tf.train.Example records and write them to TFRecord files.
def csv_dataset_to_tfrecords(base_filename, dataset,
                             n_shards,  # number of files (shards) to write
                             steps_per_shard,  # number of batches per shard
                             compression_type = None):
    options = tf.io.TFRecordOptions(
        compression_type = compression_type)
    all_filenames = []

    for shard_id in range(n_shards):  # loop over the shard files to generate
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(
            base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            # dataset.take(steps_per_shard) alone would start a fresh pass over the
            # dataset for every shard; skip() first moves past earlier shards' batches.
            for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):  # each element is one batch
                for x_example, y_example in zip(x_batch, y_batch):  # unbatch
                    writer.write(
                        serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames
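The nested for loop over zip(x_batch, y_batch) undoes the batching by hand. Recent TensorFlow 2.x versions also provide tf.data.Dataset.unbatch(), which yields the same per-example pairs. The sketch below compares both on a toy dataset that stands in for train_set (6 examples, 2 features each, batch size 3 are arbitrary choices):

```python
import tensorflow as tf

# Toy stand-in for train_set: 6 examples, 2 features each, batched by 3.
x = tf.reshape(tf.range(12, dtype=tf.float32), [6, 2])
y = tf.reshape(tf.range(6, dtype=tf.float32), [6, 1])
batched = tf.data.Dataset.from_tensor_slices((x, y)).batch(3)

# Manual unbatching, as in csv_dataset_to_tfrecords.
manual = [(xe.numpy().tolist(), ye.numpy().tolist())
          for x_batch, y_batch in batched
          for xe, ye in zip(x_batch, y_batch)]

# Built-in alternative: unbatch() yields one (x, y) pair per example.
builtin = [(xe.numpy().tolist(), ye.numpy().tolist())
           for xe, ye in batched.unbatch()]

print(len(manual), manual == builtin)  # 6 True
```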

Generate uncompressed TFRecord files:

# Generate uncompressed TFRecord files
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards  # number of batches each shard holds
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

# Output directory
output_dir = "generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_fielnames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard, None)
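Because both divisions are integer divisions, each split writes slightly fewer examples than its row count. A quick check of how many examples end up on disk, using the row counts above and batch_size = 32 from the full code (the helper names are hypothetical):

```python
def examples_written(n_rows, batch_size, n_shards):
    """Hypothetical helper mirroring n_rows // batch_size // n_shards."""
    steps_per_shard = n_rows // batch_size // n_shards  # batches per shard
    return steps_per_shard, steps_per_shard * batch_size * n_shards


def shard_names(base_filename, n_shards):
    """Same naming pattern as csv_dataset_to_tfrecords."""
    return ['{}_{:05d}-of-{:05d}'.format(base_filename, shard_id, n_shards)
            for shard_id in range(n_shards)]


for split, n_rows in [("train", 11610), ("valid", 3880), ("test", 5170)]:
    steps, total = examples_written(n_rows, batch_size=32, n_shards=20)
    print(split, steps, total)
# train 18 11520
# valid 6 3840
# test 8 5120

print(shard_names("train", 20)[:2])  # ['train_00000-of-00020', 'train_00001-of-00020']
```

So 11520 training examples are written, 90 fewer than the 11610 rows. Since the CSV datasets are repeated, the writer never runs out of data; the remainder simply is not written.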

Generate GZIP-compressed TFRecord files:

# Generate GZIP-compressed TFRecord files
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords_zip"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard,
    compression_type = "GZIP")
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard,
    compression_type = "GZIP")
test_tfrecord_fielnames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard,
    compression_type = "GZIP")

3. Hands-on: reading the TFRecord files

expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example,
                                         expected_features)
    return example["input_features"], example["label"]

def tfrecords_reader_dataset(filenames, n_readers=5,
                             batch_size=32, n_parse_threads=5,
                             shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(
            filename, compression_type = "GZIP"),
        cycle_length = n_readers
    )
    dataset = dataset.shuffle(shuffle_buffer_size)  # shuffle() returns a new dataset; reassign it
    dataset = dataset.map(parse_example,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,
                                           batch_size = 3)
for x_batch, y_batch in tfrecords_train.take(2):
    print(x_batch)
    print(y_batch)
tf.Tensor(
[[-1.1334558   1.0731637  -0.38411045 -0.19008651 -0.45323023 -0.06815852
   1.0292323  -1.3457658 ]
 [-1.0591781   1.3935647  -0.02633197 -0.1100676  -0.6138199  -0.09695935
   0.3247131  -0.03747724]
 [-0.24628098  1.2333642  -0.41765466  0.02003763  0.16009521  0.15687561
  -0.7250671   0.6965625 ]], shape=(3, 8), dtype=float32)
tf.Tensor(
[[1.598]
 [0.672]
 [1.849]], shape=(3, 1), dtype=float32)
tf.Tensor(
[[ 0.51780826  0.6726626   0.15955476 -0.22812782  0.06029374 -0.03868682
   0.94524986 -1.2808508 ]
 [-0.07763786  0.91296333 -0.33707762 -0.27174184 -0.81251556  0.05529061
  -0.6644131   0.5916997 ]
 [-0.3680344  -1.0094423   9.957168    8.32342    -1.1128273  -0.14463872
   1.3418335  -0.21224862]], shape=(3, 8), dtype=float32)
tf.Tensor(
[[1.883]
 [1.652]
 [1.406]], shape=(3, 1), dtype=float32)
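tfrecords_reader_dataset hard-codes compression_type = "GZIP", so it can only read the files in generate_tfrecords_zip. A variant that takes the compression type as a parameter can read either output directory. The sketch below (function names hypothetical, TensorFlow 2.x assumed) writes two tiny GZIP shards to a temporary directory and reads them back. It uses from_tensor_slices instead of list_files only to keep the file order deterministic, and it omits repeat() so the dataset is finite.

```python
import os
import tempfile

import tensorflow as tf


def write_shard(path, values, compression_type):
    """Hypothetical helper: write one float label per Example."""
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    with tf.io.TFRecordWriter(path, options) as writer:
        for v in values:
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(float_list=tf.train.FloatList(value=[v])),
            }))
            writer.write(example.SerializeToString())


def make_reader(filenames, compression_type, n_readers=2, shuffle_buffer_size=100):
    """Hypothetical reader: compression type is a parameter, not hard-coded."""
    spec = {"label": tf.io.FixedLenFeature([1], tf.float32)}
    dataset = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = dataset.interleave(
        lambda f: tf.data.TFRecordDataset(f, compression_type=compression_type),
        cycle_length=n_readers)
    dataset = dataset.shuffle(shuffle_buffer_size)  # reassign the shuffled dataset
    return dataset.map(lambda s: tf.io.parse_single_example(s, spec)["label"])


tmp = tempfile.mkdtemp()
paths = [os.path.join(tmp, "part_{}.tfrecord.gz".format(i)) for i in range(2)]
write_shard(paths[0], [1.0, 2.0], "GZIP")
write_shard(paths[1], [3.0, 4.0], "GZIP")

labels = sorted(float(t.numpy()[0]) for t in make_reader(paths, "GZIP"))
print(labels)  # [1.0, 2.0, 3.0, 4.0]
```

Passing compression_type=None (or "") to the same reader would read the uncompressed shards instead.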

Appendix: full code

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf

from tensorflow import keras

print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

source_dir = "./generate_csv/"
print(os.listdir(source_dir))

def get_filenames_by_prefix(source_dir, prefix_name):
    all_files = os.listdir(source_dir)  # all file names in the directory
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))  # full path
    return results

train_filenames = get_filenames_by_prefix(source_dir, "train")
valid_filenames = get_filenames_by_prefix(source_dir, "valid")
test_filenames = get_filenames_by_prefix(source_dir, "test")


# Read the CSV files
def parse_csv_line(line, n_fields = 9):
    defs = [tf.constant(np.nan)] * n_fields
    parsed_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_fields[0:-1])
    y = tf.stack(parsed_fields[-1:])
    return x, y


def csv_reader_dataset(filenames, n_readers=5,
                       batch_size=32, n_parse_threads=5,
                       shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length = n_readers
    )
    dataset = dataset.shuffle(shuffle_buffer_size)  # shuffle() returns a new dataset; reassign it
    dataset = dataset.map(parse_csv_line,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 32
train_set = csv_reader_dataset(train_filenames,
                               batch_size = batch_size)
valid_set = csv_reader_dataset(valid_filenames,
                               batch_size = batch_size)
test_set = csv_reader_dataset(test_filenames,
                              batch_size = batch_size)

# Next, iterate over the three datasets and write the data into TFRecord files
def serialize_example(x, y):
    """Convert x, y to a tf.train.Example and serialize it."""
    input_features = tf.train.FloatList(value = x)
    label = tf.train.FloatList(value = y)
    features = tf.train.Features(
        feature = {
            "input_features": tf.train.Feature(
                float_list = input_features),
            "label": tf.train.Feature(float_list = label)
        }
    )
    example = tf.train.Example(features = features)
    return example.SerializeToString()


# Convert the dataset read from CSV into tf.train.Example records and write them to TFRecord files.
def csv_dataset_to_tfrecords(base_filename, dataset,
                             n_shards,  # number of files (shards) to write
                             steps_per_shard,  # number of batches per shard
                             compression_type = None):
    options = tf.io.TFRecordOptions(
        compression_type = compression_type)
    all_filenames = []

    for shard_id in range(n_shards):  # loop over the shard files to generate
        filename_fullpath = '{}_{:05d}-of-{:05d}'.format(
            base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            # dataset.take(steps_per_shard) alone would start a fresh pass over the
            # dataset for every shard; skip() first moves past earlier shards' batches.
            for x_batch, y_batch in dataset.skip(shard_id * steps_per_shard).take(steps_per_shard):  # each element is one batch
                for x_example, y_example in zip(x_batch, y_batch):  # unbatch
                    writer.write(
                        serialize_example(x_example, y_example))
        all_filenames.append(filename_fullpath)
    return all_filenames

# Generate uncompressed TFRecord files
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards  # number of batches each shard holds
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

# Output directory
output_dir = "generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_fielnames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard, None)

# Generate GZIP-compressed TFRecord files
n_shards = 20
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "generate_tfrecords_zip"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir, "train")
valid_basename = os.path.join(output_dir, "valid")
test_basename = os.path.join(output_dir, "test")

train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_basename, train_set, n_shards, train_steps_per_shard,
    compression_type = "GZIP")
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_basename, valid_set, n_shards, valid_steps_per_shard,
    compression_type = "GZIP")
test_tfrecord_fielnames = csv_dataset_to_tfrecords(
    test_basename, test_set, n_shards, test_steps_per_shard,
    compression_type = "GZIP")

expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example,
                                         expected_features)
    return example["input_features"], example["label"]

def tfrecords_reader_dataset(filenames, n_readers=5,
                             batch_size=32, n_parse_threads=5,
                             shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(
            filename, compression_type = "GZIP"),
        cycle_length = n_readers
    )
    dataset = dataset.shuffle(shuffle_buffer_size)  # shuffle() returns a new dataset; reassign it
    dataset = dataset.map(parse_example,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

batch_size = 32
tfrecords_train_set = tfrecords_reader_dataset(
    train_tfrecord_filenames, batch_size = batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(
    valid_tfrecord_filenames, batch_size = batch_size)
tfrecords_test_set = tfrecords_reader_dataset(
    test_tfrecord_fielnames, batch_size = batch_size)

model = keras.models.Sequential([
    keras.layers.Dense(30, activation='relu',
                       input_shape=[8]),
    keras.layers.Dense(1),
])
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(
    patience=5, min_delta=1e-2)]

history = model.fit(tfrecords_train_set,
                    validation_data = tfrecords_valid_set,
                    steps_per_epoch = 11610 // batch_size,
                    validation_steps = 3870 // batch_size,
                    epochs = 100,
                    callbacks = callbacks)

model.evaluate(tfrecords_test_set, steps = 5160 // batch_size)

 
