我正在使用tf.data.TFRecordDataset从TFRecord文件读取数据集。在
我想知道每一步都在处理哪个时代。在
详细信息如下:100个样本保存在一个TFRecord文件中,batch_size设置为50,epoch_num设置为5。在
下面是我的简化代码:def read_and_decode_TFRecordDataset(tfrecords_path, batch_size, epoch_num):
dataset = tf.data.TFRecordDataset(tfrecords_path)
dataset = dataset.map(parser_deblur)
epoch = tf.data.Dataset.range(epoch_num)
dataset = epoch.flat_map(lambda i: tf.data.Dataset.zip(
(dataset, tf.data.Dataset.from_tensors(i).repeat())))
dataset = dataset.repeat(epoch_num).shuffle(1000).batch(batch_size)
iterator = dataset.make_one_shot_iterator()
(face_blur_batch, face_gt_batch), epochNow = iterator.get_next()
return face_blur_batch, face_gt_batch, epochNow
print EPOCH: {epochNow}, STEP: {step}
我期望的是:
^{pr2}$
但实际产出是:
^{3}$
我不知道什么是纪元?这似乎是随机的。每次跑步都不一样。在
你知道如何修复上面的代码吗?或者如何通过其他方法获得纪元计数器?在