Reference
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle
buffer_size的含义——Dataset.map , Dataset.prefetch and Dataset.shuffle_Eartha1995的博客-CSDN博客_buffer_size
本文就不重复上面3篇文章的代码了。免得大家又用batch()或者next()的角度去解构它,虽然这些解构方式都是正确的。
一、直观生动的代码案例
话不多说,上代码!
dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(2)
print(list(dataset.as_numpy_iterator()))
上面这个代码块,输出为
[1, 2, 3, 0, 5, 6, 7, 4, 8, 9]
而如果把buffer size调为10,输出明显更混乱(如下所见)。
dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(10)
print(list(dataset.as_numpy_iterator()))
# [3, 5, 4, 0, 2, 6, 1, 7, 9, 8]
因此,如果你不想要懂原理,想要真正地、有效地用到这个shuffle方法,直接把size设置为整个数据集大小,或者成倍大小(感觉没必要)。
即使这样确实会增大内存消耗,但是不这样做,打乱效果会很差。
二、详细贴心的原理讲解
首先讲buffer size为2的情况。
输入
dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(2)
print(list(dataset.as_numpy_iterator()))
输出
[1, 2, 3, 0, 5, 6, 7, 4, 8, 9]
这个输出怎么得到的?这真的是手把手、一步步打出来的呀!不能再详细了!
1、按顺序在dataset中先取出buffer size大小的数据,而后存入buffer缓冲区;第一次取的是[0,1]。
2、在缓冲区随机取出一个元素输出到output区中去;第一次输出的是1,所以output的第一个元素是1;于是缓冲区变成了[0, ]。
3、接着,再按顺序把dataset中的2放入buffer缓冲区;于是缓冲区变成了[0, 2]。
4、然后再随机从buffer取一个元素输出,第二次输出2;于是缓冲区变成了[0, ];output的第二个元素是2。
5、接着,把dataset中的3放入buffer缓冲区;于是缓冲区变成了[0, 3]。
6、然后再随机从buffer取一个元素输出,第三次输出3;于是缓冲区变成了[0, ];output的第三个元素是3。
7、接着,把dataset中的4放入buffer缓冲区;于是缓冲区变成了[0, 4]。
8、然后再随机从buffer取一个元素输出,第四次输出0;于是缓冲区变成了[ , 4];output的第四个元素是0。
9、接着,把dataset中的5放入buffer缓冲区;于是缓冲区变成了[5, 4]。
10、然后再随机从buffer取一个元素输出,第五次输出5;于是缓冲区变成了[ , 4];output的第五个元素是5。
11、接着,把dataset中的6放入buffer缓冲区;于是缓冲区变成了[6, 4]。
12、然后再随机从buffer取一个元素输出,第六次输出6;于是缓冲区变成了[ , 4];output的第六个元素是6。
13、接着,把dataset中的7放入buffer缓冲区;于是缓冲区变成了[7, 4]。
14、然后再随机从buffer取一个元素输出,第七次输出7;于是缓冲区变成了[ , 4];output的第七个元素是7。
15、接着,把dataset中的8放入buffer缓冲区;于是缓冲区变成了[8, 4]。
16、然后再随机从buffer取一个元素输出,第八次输出4;于是缓冲区变成了[8, ];output的第八个元素是4。
17、接着,把dataset中的9放入buffer缓冲区;于是缓冲区变成了[8, 9]。
18、然后再随机从buffer取一个元素输出,第九次输出8;于是缓冲区变成了[ , 9];output的第九个元素是8。
19、最后,buffer区发现dataset区已经被他“掏空了”(可怜巴巴...),于是只好无可奈何用小于buffer size的剩下的元素进行随机输出;第十次输出啥呢?笨蛋,只能输出9 了。
20、于是output区的最后一个元素是9。
三、结合batch进一步理解shuffle
3.1 当reshuffle_each_iteration=None时
如果这个参数不设置,也即为None时,和它为True时效果竟然一样(俺也不知道为啥...不过既然如此,建议手动设置为False)。
如果这个参数为True,那么每一次迭代(比如for循环、list()/tuple()方法、iter()方法)的时候,都会重新乱序。
这样也就导致每一次的迭代输出,结果不一样。
dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(3)
print(list(dataset.as_numpy_iterator()))
dataset = dataset.batch(2)
print(list(dataset))
for ds in dataset:
print(ds.numpy())
# [0, 2, 1, 4, 5, 3, 8, 7, 9, 6]
# [<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 3], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 2], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 5], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([7, 6], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([9, 8], dtype=int64)>]
# [1 2]
# [0 5]
# [4 3]
# [8 7]
# [9 6]
dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(3,reshuffle_each_iteration=True)
print(list(dataset.as_numpy_iterator()))
dataset = dataset.batch(2)
print(list(dataset))
for ds in dataset:
print(ds.numpy())
# [1, 2, 4, 0, 3, 6, 7, 9, 5, 8]
# [<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 5], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 7], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([6, 8], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([9, 2], dtype=int64)>]
# [1 3]
# [0 2]
# [5 4]
# [7 9]
# [6 8]
3.2 当reshuffle_each_iteration=False时
dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(3,reshuffle_each_iteration=False)
print(list(dataset.as_numpy_iterator()))
dataset = dataset.batch(2)
print(list(dataset))
for ds in dataset:
print(ds.numpy())
# 输出如下:
# [0, 2, 1, 5, 4, 7, 3, 9, 6, 8]
# <BatchDataset shapes: (None,), types: tf.int64>
# [<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 2], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 5], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 7], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 9], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([6, 8], dtype=int64)>]
# [0 2]
# [1 5]
# [4 7]
# [3 9]
# [6 8]
3.3 batch本身是不会乱序的,以原序按批次分割
dataset = tf.data.Dataset.range(10)
# dataset = dataset.shuffle(3)
print(list(dataset.as_numpy_iterator()))
dataset = dataset.batch(2)
print(list(dataset))
for ds in dataset:
print(ds.numpy())
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# [<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([2, 3], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 5], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([6, 7], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([8, 9], dtype=int64)>]
# [0 1]
# [2 3]
# [4 5]
# [6 7]
# [8 9]