最普通的情况
# 创建0-30的数据集,每个batch取个8数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(4):
value = sess.run(next_element)
print(value)
输出结果为:
[0 1 2 3 4 5 6 7]
[ 8 9 10 11 12 13 14 15]
[16 17 18 19 20 21 22 23]
[24 25 26 27 28 29]
可以看到,最后一行的输出只有6个数,说明batch方法还是很智能的,并且在这种情况下batch方法取出来的数据并不会存在重复的情况。
如果将迭代次数增加为5次呢?
# 创建0-30的数据集,每个batch取个8数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
value = sess.run(next_element)
print(value)
# 输出结果:
# outOfRangeError (see above for traceback): End of sequence
此时可以看到会出现outOfRangeError错误,但是我们有时候又不想去计算每次到底循环多少次才能恰好的用batch方法取完数据。这个时候repeate方法就有用了。
# 创建0-30的数据集,每个batch取个数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8)
dataset=dataset.repeat(count=None)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
value = sess.run(next_element)
print(value)
输出结果为:
[0 1 2 3 4 5 6 7]
[ 8 9 10 11 12 13 14 15]
[16 17 18 19 20 21 22 23]
[24 25 26 27 28 29]
[0 1 2 3 4 5 6 7]
可以看到batch方法已经正确的重复取出数据了,repeate方法当中的参数count代表的是最多将数据重复几次,如果设置了count参数,但是循环的次数依然过大,那么还是会报错,所以为了保险起见,最好是不传入参数,或者count=None。
dataset.shuffle方法的使用
shuffle方法用于打乱数据之后取出,它的可设置参数buffer_size代表的是缓冲区的大小。buffer_size的大小越大代表取出的数据混乱程度越高,其中buffer_size=1,代表有序状态,此时和batch方法一样。
# 创建0-30的数据集,每个batch取个数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).shuffle(buffer_size=100).batch(8)
dataset=dataset.repeat(count=None)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
value = sess.run(next_element)
print(value)
输出结果:
[ 6 4 11 24 15 26 18 27]
[12 28 5 25 0 19 14 20]
[17 10 7 8 3 21 13 2]
[22 16 29 9 23 1]
[11 12 14 5 3 0 26 22]
注意:shuffle的顺序很重要,一般建议是先执行shuffle方法,接着采用batch方法,这样是为了保证在整体数据打乱之后再取出batch_size大小的数据。如果先采取batch方法再采用shuffle方法,那么此时就只是对batch进行shuffle,而batch里面的数据顺序依旧是有序的,那么随机程度会减弱。示例如下:
将上一段代码的batch方法放到shuffle方法的前面
# 创建0-30的数据集,每个batch取个8数。
import tensorflow as tf
dataset = tf.data.Dataset.range(30).batch(8).shuffle(buffer_size=100)
dataset=dataset.repeat(count=None)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for i in range(5):
value = sess.run(next_element)
print(value)
输出:
[16 17 18 19 20 21 22 23]
[ 8 9 10 11 12 13 14 15]
[24 25 26 27 28 29]
[0 1 2 3 4 5 6 7]
[24 25 26 27 28 29]
可以看出,数据依然是有序的,起不到打乱的作用。
也可参考博客:参考博客