Tensorflow datasets.shuffle repeat batch方法的妙用,以及batch是否会重复取值

最普通的情况

# 创建0-30的数据集,每个batch取个8数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(4):
        value = sess.run(next_element)
        print(value)

输出结果为:

[0 1 2 3 4 5 6 7]
[ 8  9 10 11 12 13 14 15]
[16 17 18 19 20 21 22 23]
[24 25 26 27 28 29]

可以看到,最后一行的输出只有6个数,说明batch方法还是很智能的,并且在这种情况下batch方法取出来的数据并不会存在重复的情况。

如果将迭代次数增加为5次呢?

# 创建0-30的数据集,每个batch取个8数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(5):
        value = sess.run(next_element)
        print(value)

# 输出结果:
# outOfRangeError (see above for traceback): End of sequence

此时可以看到会出现outOfRangeError错误,但是我们有时候又不想去计算每次到底循环多少次才能恰好的用batch方法取完数据。这个时候repeate方法就有用了。

# 创建0-30的数据集,每个batch取个数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8)
dataset=dataset.repeat(count=None)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(5):
        value = sess.run(next_element)
        print(value)

输出结果为:

[0 1 2 3 4 5 6 7]
[ 8  9 10 11 12 13 14 15]
[16 17 18 19 20 21 22 23]
[24 25 26 27 28 29]
[0 1 2 3 4 5 6 7]

可以看到batch方法已经正确的重复取出数据了,repeate方法当中的参数count代表的是最多将数据重复几次,如果设置了count参数,但是循环的次数依然过大,那么还是会报错,所以为了保险起见,最好是不传入参数,或者count=None。

dataset.shuffle方法的使用

shuffle方法用于打乱数据之后取出,它的可设置参数buffer_size代表的是缓冲区的大小。buffer_size的大小越大代表取出的数据混乱程度越高,其中buffer_size=1,代表有序状态,此时和batch方法一样。

# 创建0-30的数据集,每个batch取个数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).shuffle(buffer_size=100).batch(8)
dataset=dataset.repeat(count=None)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(5):
        value = sess.run(next_element)
        print(value)

输出结果:

[ 6  4 11 24 15 26 18 27]
[12 28  5 25  0 19 14 20]
[17 10  7  8  3 21 13  2]
[22 16 29  9 23  1]
[11 12 14  5  3  0 26 22]

注意:shuffle的顺序很重要,一般建议是先执行shuffle方法,接着采用batch方法,这样是为了保证在整体数据打乱之后再取出batch_size大小的数据。如果先采取batch方法再采用shuffle方法,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。示例如下:

将上一段代码的batch方法放到shuffle方法的前面

# 创建0-30的数据集,每个batch取个8数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8).shuffle(buffer_size=100)
dataset=dataset.repeat(count=None)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(5):
        value = sess.run(next_element)
        print(value)

输出:

[16 17 18 19 20 21 22 23]
[ 8  9 10 11 12 13 14 15]
[24 25 26 27 28 29]
[0 1 2 3 4 5 6 7]
[24 25 26 27 28 29]

可以看出,数据依然是有序的,起不到打乱的作用。

也可参考博客:参考博客

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值