【ML&DL学习】12 tf tfrecord basic api generate tfrecord

本文介绍如何将dataset转换为TFRecord格式,并利用tf.data进行读取,适用于机器学习和深度学习模型的训练。内容包括TFRecord文件的生成、读取、序列化与压缩操作。
摘要由CSDN通过智能技术生成

先把dataset转出tfrecord格式,再用tf.data进行读取,然后模型进行训练

# tfrecord 文件格式
# -> tf.train.Example
#    -> tf.train.Features -> {"key": tf.train.Feature}
#       -> tf.train.Feature -> tf.train.ByteList/FloatList/Int64List

favorite_books = [name.encode('utf-8')
                  for name in ["machine learning", "cc150"]]
favorite_books_bytelist = tf.train.BytesList(value = favorite_books)
print(favorite_books_bytelist)

hours_floatlist = tf.train.FloatList(value = [15.5, 9.5, 7.0, 8.0])
print(hours_floatlist)

age_int64list = tf.train.Int64List(value = [42])
print(age_int64list)

features = tf.train.Features(
    feature = {
        "favorite_books": tf.train.Feature(
            bytes_list = favorite_books_bytelist),
        "hours": tf.train.Feature(
            float_list = hours_floatlist),
        "age": tf.train.Feature(int64_list = age_int64list),
    }
)
print(features)

在这里插入图片描述
example 与 序列化(压缩减小size)
在这里插入图片描述
保存tfrecord文件
在这里插入图片描述
在这里插入图片描述
读取
在这里插入图片描述
反序列化
在这里插入图片描述
tfrecord压缩
在这里插入图片描述
在这里插入图片描述
压缩后读取tfrecord
在这里插入图片描述

tfrecord 生成 读取 使用

在这里插入图片描述

source_dir = "./generate_csv/"

def get_filenames_by_prefix(source_dir, prefix_name):
    all_files = os.listdir(source_dir)
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))
    return results

train_filenames = get_filenames_by_prefix(source_dir, "train")
valid_filenames = get_filenames_by_prefix(source_dir, "valid")
test_filenames = get_filenames_by_prefix(source_dir, "test")

import pprint
pprint.pprint(train_filenames)
pprint.pprint(valid_filenames)
pprint.pprint(test_filenames)

在这里插入图片描述
读取csv装成dataset
在这里插入图片描述
读取dataset进行遍历,把取到的数据写到tfrecord文件里去
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
读取tfrecord

expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example,
                                         expected_features)
    return example["input_features"], example["label"]

def tfrecords_reader_dataset(filenames, n_readers=5,
                             batch_size=32, n_parse_threads=5,
                             shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(
            filename, compression_type = "GZIP"),
        cycle_length = n_readers
    )
    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_example,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,
                                           batch_size = 3)
for x_batch, y_batch in tfrecords_train.take(2):
    print(x_batch)
    print(y_batch)

在这里插入图片描述
生成训练中使用的数据集

batch_size = 32
tfrecords_train_set = tfrecords_reader_dataset(
    train_tfrecord_filenames, batch_size = batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(
    valid_tfrecord_filenames, batch_size = batch_size)
tfrecords_test_set = tfrecords_reader_dataset(
    test_tfrecord_fielnames, batch_size = batch_size)

训练

model = keras.models.Sequential([
    keras.layers.Dense(30, activation='relu',
                       input_shape=[8]),
    keras.layers.Dense(1),
])
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(
    patience=5, min_delta=1e-2)]

history = model.fit(tfrecords_train_set,
                    validation_data = tfrecords_valid_set,
                    steps_per_epoch = 11160 // batch_size,
                    validation_steps = 3870 // batch_size,
                    epochs = 100,
                    callbacks = callbacks)

在这里插入图片描述
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值