【TensorFlow】如何判断当前进行到第几个epoch了?

def read_and_decode_TFRecordDataset_deblur(tfrecords_path, batch_size, epoch_num):
    dataset = tf.data.TFRecordDataset(tfrecords_path)
    dataset = dataset.map(parser_deblur).shuffle(buffer_size=100*batch_size)
    epoch = tf.data.Dataset.range(epoch_num)
    dataset = epoch.flat_map(lambda i: tf.data.Dataset.zip(
        (dataset, tf.data.Dataset.from_tensors(i).repeat())))
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    (face_blur_batch, face_gt_batch), epochNow = iterator.get_next()
    return face_blur_batch, face_gt_batch, epochNow

其中有2点比较重要:

1. 一定要在指定epoch之前shuffle()

REFERENCE: How to set epoch counter when using TFRecordDataset?

2. shuffle的buffer_size如何设置

REFERENCE: Meaning of buffer_size in Dataset.map , Dataset.prefetch and Dataset.shuffle

REFERENCE: tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解

 

  • 4
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值