TensorFlow 2.x: Feeding TFRecord Data into a tf.Keras Model

This post reads data from TFRecord files and feeds it into a tf.Keras model for training.

Code example:

# Required imports
import pprint
import tensorflow as tf
from tensorflow import keras

# Show the generated TFRecord shard filenames
pprint.pprint(train_tfrecord_filenames)
pprint.pprint(vaild_tfrecord_filenames)
pprint.pprint(test_tfrecord_filenames)

['generate_tfrecords\train_00000-of-00020',
 'generate_tfrecords\train_00001-of-00020',
 'generate_tfrecords\train_00002-of-00020',
 ...]
['generate_tfrecords\vaild_00000-of-00020',
 'generate_tfrecords\vaild_00001-of-00020',
 'generate_tfrecords\vaild_00002-of-00020',
 ...
 'generate_tfrecords\vaild_00019-of-00020']
['generate_tfrecords\test_00000-of-00020',
 'generate_tfrecords\test_00001-of-00020',
 'generate_tfrecords\test_00002-of-00020',
 ...]

# Feature description dict used to parse serialized Examples
expected_features = {
    "input_features" : tf.io.FixedLenFeature([8], dtype=tf.float32),
    "label" : tf.io.FixedLenFeature([1], dtype=tf.float32)
}

# Parse one serialized Example into an (input_features, label) pair
def parse_example(serialized_example):
    example = tf.io.parse_single_example(serialized_example, expected_features)
    return example["input_features"], example["label"]
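
For reference, this schema matches records written with tf.io.TFRecordWriter. Below is a minimal sketch of the writing side, assuming each record stores 8 float features and 1 float label under the same two keys (the shard name and the random data are purely illustrative, not from the original post):

import numpy as np

def serialize_example(x, y):
    # Pack one sample into a tf.train.Example using the two keys that
    # expected_features declares.
    feature = {
        "input_features": tf.train.Feature(
            float_list=tf.train.FloatList(value=x)),
        "label": tf.train.Feature(
            float_list=tf.train.FloatList(value=[y])),
    }
    example = tf.train.Example(features=tf.train.Features(feature=feature))
    return example.SerializeToString()

# Illustrative only: write a few random samples to one shard.
with tf.io.TFRecordWriter("generate_tfrecords/demo_00000-of-00001") as writer:
    for _ in range(5):
        writer.write(serialize_example(np.random.rand(8), 1.0))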
# Read the TFRecord files into a single Dataset
# n_reader: number of files read in parallel
# n_parse_threads: number of parallel calls used when parsing records
# shuffle_buffer_size: size of the shuffle buffer
def tfrecord_reader_dataset(filenames, n_reader=5, batch_size=32,
                            n_parse_threads=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    # repeat() with no argument repeats the data indefinitely; the
    # training set is reused across epochs, and fit() below bounds each
    # epoch with steps_per_epoch.
    dataset = dataset.repeat()
    # interleave(): open each shard and weave the records into one dataset
    dataset = dataset.interleave(
        lambda filename: tf.data.TFRecordDataset(filename),
        cycle_length = n_reader
    )
    dataset = dataset.shuffle(shuffle_buffer_size)
    # map(): apply parse_example to decode each serialized record
    dataset = dataset.map(parse_example, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    # prefetch overlaps preprocessing with training (idiomatic addition)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

batch_size = 32
tfrecords_train_set = tfrecord_reader_dataset(train_tfrecord_filenames, batch_size=batch_size)
tfrecords_valid_set = tfrecord_reader_dataset(vaild_tfrecord_filenames, batch_size=batch_size)
tfrecords_test_set = tfrecord_reader_dataset(test_tfrecord_filenames, batch_size=batch_size)
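
A quick sanity check (a sketch, not in the original post) confirms the pipeline yields batches of the expected shapes:

# Pull one batch from the infinite, repeated training dataset.
for x_batch, y_batch in tfrecords_train_set.take(1):
    print(x_batch.shape)  # (32, 8)
    print(y_batch.shape)  # (32, 1)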
# Define a network with Keras and train it on the datasets

# Use the Sequential model: tf.keras.models.Sequential()

model = keras.models.Sequential([
    # keras.layers.Flatten(input_shape=x_train.shape[1:]) is unnecessary
    # here: the 8 input features are already flat.
    keras.layers.Dense(30, activation="relu", input_shape=[8]),
    keras.layers.Dense(1),
])

# Compile the model
model.compile(loss="mean_squared_error",  # loss: mean squared error
              optimizer="sgd")            # optimizer: stochastic gradient descent

# Callbacks: stop early once val_loss stops improving
callbacks = [
    keras.callbacks.EarlyStopping(patience=5, min_delta=1e-3),
]

# Train the model; fit() returns a History object.
# Because the datasets repeat indefinitely, steps_per_epoch and
# validation_steps bound each epoch (11160 and 3870 are the train and
# validation sample counts).
history = model.fit(tfrecords_train_set,
                    validation_data = tfrecords_valid_set,
                    steps_per_epoch = 11160 // batch_size,
                    validation_steps = 3870 // batch_size,
                    epochs=50,
                    callbacks = callbacks) 


Train for 348 steps, validate for 120 steps
Epoch 1/50
348/348 [==============================] - 2s 7ms/step - loss: 0.8175 - val_loss: 0.6190
Epoch 2/50
348/348 [==============================] - 1s 3ms/step - loss: 0.5461 - val_loss: 0.5317
Epoch 3/50
348/348 [==============================] - 1s 3ms/step - loss: 0.5030 - val_loss: 0.5068
Epoch 4/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4651 - val_loss: 0.4862
Epoch 5/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4606 - val_loss: 0.4791
Epoch 6/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4475 - val_loss: 0.4763
Epoch 7/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4472 - val_loss: 0.4687
Epoch 8/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4308 - val_loss: 0.4683
Epoch 9/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4329 - val_loss: 0.4626
Epoch 10/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4227 - val_loss: 0.4671
Epoch 11/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4220 - val_loss: 0.4711
Epoch 12/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4075 - val_loss: 0.4541
Epoch 13/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4067 - val_loss: 0.4576
Epoch 14/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4046 - val_loss: 0.4541
Epoch 15/50
348/348 [==============================] - 1s 3ms/step - loss: 0.4005 - val_loss: 0.4542
Epoch 16/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3957 - val_loss: 0.4489
Epoch 17/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3945 - val_loss: 0.5035
Epoch 18/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3926 - val_loss: 0.4486
Epoch 19/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3825 - val_loss: 0.4390
Epoch 20/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3801 - val_loss: 0.4349
Epoch 21/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3805 - val_loss: 0.4497
Epoch 22/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3763 - val_loss: 0.4339
Epoch 23/50
348/348 [==============================] - 1s 4ms/step - loss: 0.3726 - val_loss: 0.4384
Epoch 24/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3690 - val_loss: 0.4317
Epoch 25/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3641 - val_loss: 0.4408
Epoch 26/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3656 - val_loss: 0.4304
Epoch 27/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3681 - val_loss: 0.4270
Epoch 28/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3613 - val_loss: 0.4374
Epoch 29/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3555 - val_loss: 0.4307
Epoch 30/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3522 - val_loss: 0.4353
Epoch 31/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3527 - val_loss: 0.4312
Epoch 32/50
348/348 [==============================] - 1s 3ms/step - loss: 0.3489 - val_loss: 0.4315
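
Training stops after 32 epochs: EarlyStopping sees no val_loss improvement of at least 1e-3 for 5 consecutive epochs after the best value (0.4270 at epoch 27). To visualize the curves stored in history, a common follow-up (assuming pandas and matplotlib are installed) is:

import pandas as pd
import matplotlib.pyplot as plt

# Plot training and validation loss per epoch.
pd.DataFrame(history.history).plot(figsize=(8, 5))
plt.grid(True)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.show()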
# Evaluate on the test set; steps bounds iteration over the repeated dataset
model.evaluate(tfrecords_test_set, steps=5160 // batch_size)
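
Finally, a brief inference sketch (not in the original): predict() ignores the label in each (features, label) pair, and steps bounds iteration over the repeated dataset (steps=3 is arbitrary):

# Predict on three test batches; the labels in the dataset are ignored.
predictions = model.predict(tfrecords_test_set, steps=3)
print(predictions.shape)  # (96, 1) == (3 * batch_size, 1)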