tensorflow dataset 基础之——tfRecord

目录

1. tfRecord介绍

2. tf.record文件

2.1 生成一个tfRecord文件

2.2 读取tfRecord文件

3. tfRecord压缩文件

3.1 将tfRecord 存成压缩文件

4. tfRecord实战


tf.record是tensorflow中独有的一个格式,故其有很多优势,在读取数据方面,tf.records有着速度快的优势

1. tfRecord介绍

-> tf.train.Example.
-> tf.train.Features -> {"key": tf.train.Feature}.
-> tf.train.Feature -> tf.train.ByteList/FloatList/Int64List

favorite_books = [name.encode('utf-8') 
                 for name in ["machine learning", "cc150"]]

# ByteList
favorite_books_bytelist = tf.train.BytesList(value = favorite_books)
print(favorite_books_bytelist)

# FloatList
hours_floatlist = tf.train.FloatList(value = [15.5, 9.5, 7.0, 8.0])
print(hours_floatlist)

# Int64List
age_int64list = tf.train.Int64List(value = [42])
print(age_int64list)

# tf.trian.Features
# tf.train.Feature
features = tf.train.Features(
    feature = {
        "favorite_books": tf.train.Feature(
            bytes_list = favorite_books_bytelist),
        "hours": tf.train.Feature(
            float_list = hours_floatlist),
        "age": tf.train.Feature(int64_list = age_int64list)
    }
)

print(features)
# tf.train.example
# An Example is a mostly-normalized data format for storing data for training and inference.
example = tf.train.Example(features=features)
print(example)

通过序列化,对其进行压缩,以减少存储空间。

serialized_example = example.SerializeToString()
print(serialized_example)

2. tf.record文件

2.1 生成一个tfRecord文件

import os
output_dir = '/content/drive/MyDrive/data/tfrecord_data'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

filename = "test.tfrecords"
filename_fullpath = os.path.join(output_dir, filename)
with tf.io.TFRecordWriter(filename_fullpath) as writer:
    for i in range(3):
        writer.write(serialized_example)

2.2 读取tfRecord文件

dataset= tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    print(serialized_example)

将序列化之后的example解析成序列化之前的example

# 定义一个字典,定义每一个feature所对应的类型
excepted_features = {
    "favorite_books": tf.io.VarLenFeature(dtype = tf.string),
    "hours": tf.io.VarLenFeature(dtype = tf.float32),
    "age": tf.io.FixedLenFeature([], dtype = tf.int64)
}

dataset = tf.data.TFRecordDataset([filename_fullpath])
for serialized_example_tensor in dataset:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        excepted_features)
    print(example)

由以上可以看出,生成的是spare_tensor. sparse_tensor在存稀疏矩阵时,效率比较高

解析 spare_tensor

for serialized_example_tensor in dataset:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        excepted_features)
    books = tf.sparse.to_dense(example["favorite_books"],
                               default_value=b"")
    for book in books:
        print(book.numpy().decode("utf-8"))

3. tfRecord压缩文件

3.1 将tfRecord 存成压缩文件

filename_zip = "test.tfrecords.zip"
filename_fullpath_zip = os.path.join(output_dir, filename_zip)
options = tf.io.TFRecordOptions(compression_type = "GZIP")
with tf.io.TFRecordWriter(filename_fullpath_zip, options) as writer:
    for i in range(3):
        writer.write(serialized_example)

读取存储的压缩文件

dataset_zip= tf.data.TFRecordDataset([filename_fullpath_zip], 
                                     compression_type="GZIP")
for serialized_example_tensor in dataset_zip:
    example = tf.io.parse_single_example(
        serialized_example_tensor,
        excepted_features)
    books = tf.sparse.to_dense(example["favorite_books"],
                               default_value=b"")
    for book in books:
        print(book.numpy().decode("utf-8"))

4. tfRecord实战

 函数——从csv中读出dataset

import numpy as np
import functools

def parse_csv_line(line, n_fields):
    defs = [tf.constant(np.nan)] * n_fields
    parse_fields = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parse_fields[0:-1])
    y = tf.stack(parse_fields[-1:])
    return x, y

# 使用functools.partial,把一个函数的某些参数给固定住(当然,也可以简单设定parse_csv_line中,n_fields=9)
parse_csv_line_9 = functools.partial(parse_csv_line, n_fields = 9)

def csv_reader_dataset(filenames, n_readers=5, batch_size=32, 
                       n_parse_threads=5, shuffle_buffer_size=10000):
    filename_dataset = tf.data.Dataset.list_files(filenames)
    filename_dataset = filename_dataset.repeat()
    dataset = filename_dataset.interleave(
        lambda filename: tf.data.TextLineDataset(filename).skip(1),
        cycle_length = n_readers
    )

    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line_9,
                          num_parallel_calls = n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

函数——分别对tf.dataset进行遍历,把其数据写入到tf.record文件中

def serialized_example(x, y):
    """converts x, y, to tf.train.Example and serialize"""
    input_features = tf.train.FloatList(value = x)
    label = tf.train.FloatList(value = y)

    features = tf.train.Features(
        feature = {
            "input_features": tf.train.Feature(float_list = input_features),
            "label": tf.train.Feature(float_list = label)
        }
    )

    example = tf.train.Example(features = features)
    return example.SerializeToString()

函数——将csv dataset 写入到 tf records

def csv_dataset_to_tfrecords(base_filename_dir, dataset, n_shards, steps_per_shard, 
                             compression_type = None):
    """
    :parms base_filename: 
    :parms dataset: csv dataset
    :parms n_shards: 将dataset存成多少个文件
    :parms steps_per_shard: 对于每个小文件,应该在dataset走多少步。
                            因为在构建dataset的时候用了repeat,dataset的遍历永远不会结束,
                            故应该算一下去遍历多少步
    :parms compression_type: 压缩类型,比如"GZIP",None,表示不压缩
    """
    if not os.path.exists(base_filename_dir):
        os.mkdir(base_filename_dir)
    options = tf.io.TFRecordOptions(compression_type = compression_type)
    all_filenames = []
    for shard_id in range(n_shards):
        filename = '{:05d}-of-{:05d}'.format(shard_id, n_shards)
        filename_fullpath = os.path.join(base_filename_dir, filename)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            for x_batch, y_batch in dataset.take(steps_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serialized_example(x_example, y_example))
    
        all_filenames.append(filename_fullpath)
    return all_filenames

实战——(1)从csv文件中读数据,得到dataset

              (2)将dataset数据存入到tfrecord文件,并返回tfrecord文件的文件名

n_shards = 20
batch_size = 32
train_steps_per_shard = 11610 // batch_size // n_shards
valid_steps_per_shard = 3880 // batch_size // n_shards
test_steps_per_shard = 5170 // batch_size // n_shards

output_dir = "/content/drive/MyDrive/data/generate_tfrecords"
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_base_dir = os.path.join(output_dir, "train")
valid_base_dir = os.path.join(output_dir, "valid")
test_base_dir = os.path.join(output_dir, "test")

csv_dir = "/content/drive/MyDrive/data/generate_csv"

# 得到csv数据的训练集,验证集,测试集的文件列表
def get_filenames(dir, prefix):
    filenames_dir = os.path.join(csv_dir, "train")
    filenames_list = os.listdir(filenames_dir)
    return [os.path.join(filenames_dir, e) for e in filenames_list]


train_csv_filenames = get_filenames(csv_dir, "train")
valid_csv_filenames = get_filenames(csv_dir, "valid")
test_csv_filenames = get_filenames(csv_dir, "test")


train_set = csv_reader_dataset(train_csv_filenames, n_readers=5, batch_size=32, 
    n_parse_threads=5, shuffle_buffer_size=10000)
valid_set = csv_reader_dataset(valid_csv_filenames, n_readers=5, batch_size=32, 
    n_parse_threads=5, shuffle_buffer_size=10000)
test_set = csv_reader_dataset(test_csv_filenames, n_readers=5, batch_size=32, 
    n_parse_threads=5, shuffle_buffer_size=10000)


train_tfrecord_filenames = csv_dataset_to_tfrecords(
    train_base_dir, train_set, n_shards, train_steps_per_shard, None)
valid_tfrecord_filenames = csv_dataset_to_tfrecords(
    valid_base_dir, valid_set, n_shards, valid_steps_per_shard, None)
test_tfrecord_filenames = csv_dataset_to_tfrecords(
    test_base_dir, test_set, n_shards, test_steps_per_shard, None)

读取tf.record文件,得到dataset

expected_features= {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

# 解析序列化的example
def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example, expected_features)

    return example["input_features"], example["label"]
def tfrecords_reader_dataset(filenames, n_readers=5, batch_size=32, 
                             n_parse_threads=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(filename, 
                                                 compression_type=None),
        cycle_length = n_readers
    )
    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_example, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

举个简单的例子,来看一下我们存入的数据是否正确

tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,
                                           batch_size=3)
for x_batch, y_batch in tfrecords_train.take(2):
    print(x_batch)
    print(y_batch)

实战——利用tfrecords中的数据生成训练集,验证集,测试集

batch_size= 32
tfrecords_train_set = tfrecords_reader_dataset(train_tfrecord_filenames, 
                                               batch_size=batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(valid_tfrecord_filenames, 
                                               batch_size=batch_size)
tfrecords_test_set = tfrecords_reader_dataset(test_tfrecord_filenames, 
                                               batch_size=batch_size)

训练模型,并测试

from tensorflow import keras

model = keras.models.Sequential([
    keras.layers.Dense(30, activation='relu',
                       input_shape=[8]),
    keras.layers.Dense(15, activation='relu'),
    keras.layers.Dense(1),
])

model.compile(loss=keras.losses.MeanSquaredError(), 
              optimizer=keras.optimizers.Adam(learning_rate=1e-3),
              metrics=["accuracy"])

callbacks = [keras.callbacks.EarlyStopping(
    patience=5, min_delta=1e-2)]

history = model.fit(train_set,
                    validation_data = valid_set,
                    steps_per_epoch = 11160 // batch_size, # 11160 为训练集的样本数
                    validation_steps = 3870 // batch_size, # 3870 为验证集的样本数
                    epochs = 100,
                    callbacks = callbacks)
model.evaluate(test_set, steps = 5160 // batch_size) # 5160表示测试集的总样本的个数

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: TensorFlow DatasetTFDS)是一个用于构建高效、可重复使用的数据管道的库。它提供了一些预先处理好的数据集,同时也支持用户自己导入自定义数据集,并可以在数据集上应用各种转换操作,例如 shuffle、batch、map、filter 等等。使用 TensorFlow Dataset 可以帮助用户更加方便地处理数据,并提高数据处理的效率。 ### 回答2: TensorFlow Dataset是一种高性能、易用、可重复的数据输入管道工具,用于处理大规模的训练和验证数据集。TensorFlow Dataset支持多种类型的数据源,如TensorFlow中的张量、numpy数组、Python生成器、CSV文件等等,并提供了一系列数据变换操作,例如shuffle、batch、map和repeat等等,有效地减少数据预处理的代码量。TensorFlow Dataset还支持多线程和预取数据操作,可以大幅度提高数据输入的效率。 使用TensorFlow Dataset有以下优点: 1.性能高:TensorFlow Dataset很好地利用了硬件资源,提供了高效的数据输入管道,极大地提高了训练效率。 2.处理数据方便:TensorFlow Dataset提供了一系列数据变换操作,方便地处理数据。 3.易用:TensorFlow Dataset简单易懂,并且有很多示例可以参考。 4.可复制:TensorFlow Dataset的数据输入是可重复的,保证了实验结果的可复现性。 在使用TensorFlow Dataset时,需要先将数据转换成tf.data.Dataset类型,然后使用map、batch、shuffle等方法进行数据处理,最后以迭代器的形式读取数据进行训练。TensorFlow Dataset的优点在于提供了一种易于使用,高效灵活的数据处理工具,可以大幅度降低数据预处理的代码量,同时保证训练效率和实验结果的可复现性,适用于大规模深度学习训练及推理。 ### 回答3: TensorFlow Dataset是Google开发的一种灵活、高效的数据载入工具,它是TensorFlow官方推荐的载入数据的方法之一。使用TensorFlow Dataset可以实现对大型数据集进行高效、快速的处理,同时也可以方便地进行数据预处理和输入函数的编写。 TensorFlow Dataset支持多种数据源,如numpy数组、csv文件、TFRecord文件、文本文件等。同时,它也支持对数据进行变换、扩充、重复、分片等操作,方便进行数据预处理。在数据输入时,TensorFlow Dataset可以自动进行多线程读取,提高数据输入的效率。 TensorFlow Dataset同时也支持多种数据集的处理操作,如shuffle、batch、repeat、map等。这些操作可以方便地实现数据集的乱序、分批、数据增强等操作。同时,TensorFlow Dataset还提供了一种方便的函数tf.data.Iterator,可以方便地实现对数据集的遍历。 TensorFlow Dataset的使用可以提高训练效率、降低内存消耗、方便数据预处理等,因此在TensorFlow的开发中得到广泛的使用。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值