4-8 tf.data读取tfrecord文件并与tf.keras结合使用

代码承接上一篇

pprint.pprint(train_tfrecord_filenames)
pprint.pprint(valid_tfrecord_filenames)
pprint.pprint(test_tfrecord_fielnames)

首先打印一下我们所生成文件的文件名,下面是其中一个文件名。
‘generate_tfrecords_zip/test_00007-of-00020’
正如在基础API里面提到的,要想解析example,必须要定义解析每个field的字典

expected_features = {
    "input_features": tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label": tf.io.FixedLenFeature([1], dtype=tf.float32)
}

def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example,
                                         expected_features)
    return example["input_features"], example["label"]

def tfrecords_reader_dataset(filenames, n_readers=5,
                             batch_size=32, n_parse_threads=5,
                             shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.repeat()
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(
            filename, compression_type = "GZIP"),
        cycle_length = n_readers
    )
    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_example,
                          num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

tfrecords_train = tfrecords_reader_dataset(train_tfrecord_filenames,
                                           batch_size = 3)
for x_batch, y_batch in tfrecords_train.take(2):
    print(x_batch)
    print(y_batch)

首先定义了解析example的字典,机器就知道如何解析这个example。
然后要定义一个map函数,定义如何对每个样本进行处理:parse_example,参数意义是序列化之后的example。先把example解析出来,使用这个函数:tf.io.parse_single_example,参数是序列化的example和指定如何解析它的字典。解析完之后把输入特征和标签返回出去。

然后定义了一个完整的函数,完成从文件名列表到具体的dataset的转变。与读取csv文件的代码比较类似,就不再对读取的过程作讲解,运行之后可以看到,取出来的数据都是正常的。

接下来使用上面定义的函数读取生成训练中使用的数据集。

batch_size = 32
tfrecords_train_set = tfrecords_reader_dataset(
    train_tfrecord_filenames, batch_size = batch_size)
tfrecords_valid_set = tfrecords_reader_dataset(
    valid_tfrecord_filenames, batch_size = batch_size)
tfrecords_test_set = tfrecords_reader_dataset(
    test_tfrecord_fielnames, batch_size = batch_size)

读取了训练集、验证集和测试集,接下来在Keras中使用这些数据。

model = keras.models.Sequential([
    keras.layers.Dense(30, activation='relu',
                       input_shape=[8]),
    keras.layers.Dense(1),
])
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(
    patience=5, min_delta=1e-2)]

history = model.fit(tfrecords_train_set,
                    validation_data = tfrecords_valid_set,
                    steps_per_epoch = 11160 // batch_size,
                    validation_steps = 3870 // batch_size,
                    epochs = 100,
                    callbacks = callbacks)

训练完成后来用测试集评估一下模型效果。

model.evaluate(tfrecords_test_set, steps = 5160 // batch_size)

到此就完成了tfrecord的实战。我们读取csv文件,转化成为tfrecord文件,再把tfrecord文件读取出来,形成一个数据集,再在tf.Keras中进行使用。需要记住的是:tfrecors是tensorflow独有的一种数据格式,在tf中有很多优化,在读取数据方面有独特的优势

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值