tensorflow中Dataset.shuffle函数的buffer size的含义解读

Reference

tensorflow - Meaning of buffer_size in Dataset.map , Dataset.prefetch and Dataset.shuffle - Stack Overflow

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle

buffer_size的含义——Dataset.map , Dataset.prefetch and Dataset.shuffle_Eartha1995的博客-CSDN博客_buffer_size

本文就不重复上面3篇文章的代码了。免得大家又用batch()或者next()的角度去解构它,虽然这些解构方式都是正确的。

一、直观生动的代码案例

话不多说,上代码!

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(2)
print(list(dataset.as_numpy_iterator()))

上面这个代码块,输出为

[1, 2, 3, 0, 5, 6, 7, 4, 8, 9]

而如果把buffer size调为10,输出明显更混乱(如下所见)。

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(10)
print(list(dataset.as_numpy_iterator()))
# [3, 5, 4, 0, 2, 6, 1, 7, 9, 8]

因此,如果你不想要懂原理,想要真正地、有效地用到这个shuffle方法,直接把size设置为整个数据集大小,或者成倍大小(感觉没必要)。 

即使这样确实会增大内存消耗,但是不这样做,打乱效果会很差。

二、详细贴心的原理讲解

首先讲buffer size为2的情况。

输入

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(2)
print(list(dataset.as_numpy_iterator()))

输出

[1, 2, 3, 0, 5, 6, 7, 4, 8, 9]

这个输出怎么得到的?这真的是手把手、一步步打出来的呀!不能再详细了!

1、按顺序在dataset中先取出buffer size大小的数据,而后存入buffer缓冲区;第一次取的是[0,1]。

2、在缓冲区随机取出一个元素输出到output区中去;第一次输出的是1,所以output的第一个元素是1;于是缓冲区变成了[0, ]。

3、接着,再按顺序把dataset中的2放入buffer缓冲区;于是缓冲区变成了[0, 2]。

4、然后再随机从buffer取一个元素输出,第二次输出2;于是缓冲区变成了[0, ];output的第二个元素是2。

5、接着,把dataset中的3放入buffer缓冲区;于是缓冲区变成了[0, 3]。

6、然后再随机从buffer取一个元素输出,第三次输出3;于是缓冲区变成了[0, ];output的第三个元素是3。

7、接着,把dataset中的4放入buffer缓冲区;于是缓冲区变成了[0, 4]。

8、然后再随机从buffer取一个元素输出,第四次输出0;于是缓冲区变成了[ , 4];output的第四个元素是0。

9、接着,把dataset中的5放入buffer缓冲区;于是缓冲区变成了[5, 4]。

10、然后再随机从buffer取一个元素输出,第五次输出5;于是缓冲区变成了[ , 4];output的第五个元素是5。

11、接着,把dataset中的6放入buffer缓冲区;于是缓冲区变成了[6, 4]。

12、然后再随机从buffer取一个元素输出,第六次输出6;于是缓冲区变成了[ , 4];output的第六个元素是6。

13、接着,把dataset中的7放入buffer缓冲区;于是缓冲区变成了[7, 4]。

14、然后再随机从buffer取一个元素输出,第七次输出7;于是缓冲区变成了[ , 4];output的第七个元素是7。

15、接着,把dataset中的8放入buffer缓冲区;于是缓冲区变成了[8, 4]。

16、然后再随机从buffer取一个元素输出,第八次输出4;于是缓冲区变成了[8,  ];output的第八个元素是4。

17、接着,把dataset中的9放入buffer缓冲区;于是缓冲区变成了[8, 9]。

18、然后再随机从buffer取一个元素输出,第九次输出8;于是缓冲区变成了[ , 9];output的第九个元素是8。

19、最后,buffer区发现dataset区已经被他“掏空了”(可怜巴巴...),于是只好无可奈何用小于buffer size的剩下的元素进行随机输出;第十次输出啥呢?笨蛋,只能输出9 了。

20、于是output区的最后一个元素是9。

三、结合batch进一步理解shuffle

3.1 当reshuffle_each_iteration=None时

如果这个参数不设置,也即为None时,和它为True时效果竟然一样(俺也不知道为啥...不过既然如此,建议手动设置为False)。

如果这个参数为True,那么每一次迭代(比如for循环、list()/tuple()方法、iter()方法)的时候,都会重新乱序。

这样也就导致每一次的迭代输出,结果不一样。

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(3)
print(list(dataset.as_numpy_iterator()))

dataset = dataset.batch(2)

print(list(dataset))

for ds in dataset:
    print(ds.numpy())
    
# [0, 2, 1, 4, 5, 3, 8, 7, 9, 6]
# [<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 3], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 2], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 5], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([7, 6], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([9, 8], dtype=int64)>]
# [1 2]
# [0 5]
# [4 3]
# [8 7]
# [9 6]

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(3,reshuffle_each_iteration=True)
print(list(dataset.as_numpy_iterator()))

dataset = dataset.batch(2)

print(list(dataset))

for ds in dataset:
    print(ds.numpy())
    
# [1, 2, 4, 0, 3, 6, 7, 9, 5, 8]
# [<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 5], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 7], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([6, 8], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([9, 2], dtype=int64)>]
# [1 3]
# [0 2]
# [5 4]
# [7 9]
# [6 8]

3.2 当reshuffle_each_iteration=False时

dataset = tf.data.Dataset.range(10)
dataset = dataset.shuffle(3,reshuffle_each_iteration=False)
print(list(dataset.as_numpy_iterator()))

dataset = dataset.batch(2)

print(list(dataset))

for ds in dataset:
    print(ds.numpy())


# 输出如下:
# [0, 2, 1, 5, 4, 7, 3, 9, 6, 8]
# <BatchDataset shapes: (None,), types: tf.int64>
# [<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 2], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([1, 5], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 7], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([3, 9], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([6, 8], dtype=int64)>]
# [0 2]
# [1 5]
# [4 7]
# [3 9]
# [6 8]

3.3 batch本身是不会乱序的,以原序按批次分割

dataset = tf.data.Dataset.range(10)
# dataset = dataset.shuffle(3)
print(list(dataset.as_numpy_iterator()))

dataset = dataset.batch(2)

print(list(dataset))

for ds in dataset:
    print(ds.numpy())
    
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
# [<tf.Tensor: shape=(2,), dtype=int64, numpy=array([0, 1], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([2, 3], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([4, 5], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([6, 7], dtype=int64)>, <tf.Tensor: shape=(2,), dtype=int64, numpy=array([8, 9], dtype=int64)>]
# [0 1]
# [2 3]
# [4 5]
# [6 7]
# [8 9]

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值