def read_and_decode_TFRecordDataset_deblur(tfrecords_path, batch_size, epoch_num):
    """Build a one-shot TFRecord input pipeline that also reports the epoch index.

    Records are read from ``tfrecords_path``, parsed with ``parser_deblur``,
    shuffled, replayed once per epoch with the epoch index attached to every
    element, and finally batched.

    Args:
        tfrecords_path: Passed straight to ``tf.data.TFRecordDataset``.
        batch_size: Number of elements per batch; the shuffle buffer holds
            ``100 * batch_size`` elements.
        epoch_num: Number of epochs; passed to ``tf.data.Dataset.range``.

    Returns:
        A 3-tuple ``(face_blur_batch, face_gt_batch, epoch_now)`` taken from
        the iterator's ``get_next()``. The first two are the pair produced by
        ``parser_deblur`` (presumably blurred input and ground truth, going by
        the names -- confirm against ``parser_deblur``); the third is the
        epoch index attached to each element of the batch.
    """
    shuffle_buffer = 100 * batch_size

    # Shuffle the parsed records *before* the per-epoch expansion below.
    records = tf.data.TFRecordDataset(tfrecords_path)
    shuffled = records.map(parser_deblur).shuffle(buffer_size=shuffle_buffer)

    def _attach_epoch(epoch_idx):
        # Pair every shuffled record with the index of the epoch it belongs to.
        epoch_labels = tf.data.Dataset.from_tensors(epoch_idx).repeat()
        return tf.data.Dataset.zip((shuffled, epoch_labels))

    epochs = tf.data.Dataset.range(epoch_num)
    batched = epochs.flat_map(_attach_epoch).batch(batch_size)

    next_element = batched.make_one_shot_iterator().get_next()
    (face_blur_batch, face_gt_batch), epoch_now = next_element
    return face_blur_batch, face_gt_batch, epoch_now
其中有2点比较重要:
1. 一定要在指定epoch之前调用shuffle():先对解析后的数据集做shuffle,再用flat_map按epoch展开,这样打乱只发生在每个epoch内部,不同epoch的样本不会在shuffle缓冲区中混在一起
REFERENCE: How to set epoch counter when using TFRecordDataset?
2. shuffle的buffer_size如何设置(上面的代码取100*batch_size)
REFERENCE: Meaning of buffer_size in Dataset.map , Dataset.prefetch and Dataset.shuffle
REFERENCE: tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解