Function
shuffle(buffer_size, seed=None, reshuffle_each_iteration=None, name=None)
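This function randomly shuffles the elements of this dataset. The dataset fills a buffer with buffer_size elements, then randomly samples elements from this buffer, replacing each selected element with a new element. For perfect shuffling, the buffer size must be greater than or equal to the full size of the dataset.
For example, if your dataset contains 10,000 elements but buffer_size is set to 1,000, then shuffle will initially select a random element from only the first 1,000 elements in the buffer. Once an element is selected, its slot in the buffer is filled by the next (i.e. the 1,001st) element, maintaining the 1,000-element buffer. reshuffle_each_iteration controls whether the shuffle order should be different for each iteration.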
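The buffer mechanism described above can be sketched in plain Python. This is only an illustrative model, not TensorFlow's implementation: the function name `buffered_shuffle` is hypothetical, and it uses the standard-library `random` module rather than TensorFlow's random ops.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T],
                     buffer_size: int,
                     seed: Optional[int] = None) -> Iterator[T]:
    """Hypothetical pure-Python model of a buffer-based shuffle."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    source = iter(items)

    # Step 1: fill the buffer with the first buffer_size elements.
    buffer: List[T] = []
    for item in source:
        buffer.append(item)
        if len(buffer) >= buffer_size:
            break

    # Step 2: emit a random buffered element, then put the next
    # input element into the slot that was just freed.
    for item in source:
        idx = rng.randrange(len(buffer))
        yield buffer[idx]
        buffer[idx] = item

    # Step 3: the input is exhausted, so drain the buffer in random order.
    while buffer:
        idx = rng.randrange(len(buffer))
        buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
        yield buffer.pop()


if __name__ == "__main__":
    # With 10,000 elements and a 1,000-element buffer, the first output
    # can only come from the first 1,000 inputs.
    shuffled = list(buffered_shuffle(range(10_000), buffer_size=1_000, seed=0))
    print(shuffled[:5])
```

In this model the element emitted at output position i always comes from the first buffer_size + i inputs, which is why a buffer smaller than the dataset only shuffles locally. With buffer_size=1 the input order is returned unchanged.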
In TensorFlow 2.X, tf.data.Dataset objects are Python iterables, so epochs can also be created by iterating over the dataset in Python:
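import tensorflow as tf

dataset = tf.data.Dataset.range(3)
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
# The order is random; the lists below are examples of possible results.
list(dataset.as_numpy_iterator())
# e.g. [1, 0, 2]
list(dataset.as_numpy_iterator())
# e.g. [1, 2, 0]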
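To make the effect of reshuffle_each_iteration concrete without depending on TensorFlow, here is a small pure-Python model. The class `ShuffledView` and its seeding scheme are hypothetical and chosen only for illustration; TensorFlow derives its per-iteration seeds differently.

```python
import random


class ShuffledView:
    """Hypothetical model of reshuffle_each_iteration (not TensorFlow code)."""

    def __init__(self, items, seed=None, reshuffle_each_iteration=True):
        self._items = list(items)
        # Pick a base seed once so repeated passes can be related to each other.
        self._base_seed = seed if seed is not None else random.randrange(2**32)
        self._reshuffle = reshuffle_each_iteration
        self._epoch = 0

    def __iter__(self):
        # With reshuffling, the pass counter is mixed into the seed, so each
        # pass gets its own order. Without it, every pass reuses the same
        # seed and therefore repeats the same order.
        epoch = self._epoch if self._reshuffle else 0
        self._epoch += 1
        rng = random.Random(f"{self._base_seed}:{epoch}")
        order = list(self._items)
        rng.shuffle(order)
        return iter(order)


if __name__ == "__main__":
    fixed = ShuffledView(range(5), seed=1, reshuffle_each_iteration=False)
    print(list(fixed), list(fixed))    # the same order twice

    varied = ShuffledView(range(5), seed=1, reshuffle_each_iteration=True)
    print(list(varied), list(varied))  # each pass is shuffled independently
```

Each call to `list(...)` starts a new pass, which plays the role of an epoch in this sketch.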
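Parameters
Parameter | Meaning |
---|---|
buffer_size | [tf.int64 scalar tf.Tensor] The number of elements from this dataset from which the new dataset will sample. |
seed | [Optional, tf.int64 scalar tf.Tensor] The random seed that will be used to create the distribution. See tf.random.set_seed for behavior. |
reshuffle_each_iteration | [Optional, boolean] If True, the dataset is pseudorandomly reshuffled each time it is iterated over. Defaults to True. |
name | [Optional] A name for the tf.data operation. |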
Return value
Return value | Meaning |
---|---|
Dataset | A tf.data.Dataset. |
Function implementation
def shuffle(self,
            buffer_size,
            seed=None,
            reshuffle_each_iteration=None,
            name=None):
  """Randomly shuffles the elements of this dataset.

  This dataset fills a buffer with `buffer_size` elements, then randomly
  samples elements from this buffer, replacing the selected elements with new
  elements. For perfect shuffling, a buffer size greater than or equal to the
  full size of the dataset is required.

  For instance, if your dataset contains 10,000 elements but `buffer_size` is
  set to 1,000, then `shuffle` will initially select a random element from
  only the first 1,000 elements in the buffer. Once an element is selected,
  its space in the buffer is replaced by the next (i.e. 1,001-st) element,
  maintaining the 1,000 element buffer.

  `reshuffle_each_iteration` controls whether the shuffle order should be
  different for each epoch. In TF 1.X, the idiomatic way to create epochs
  was through the `repeat` transformation:

  ```python
  dataset = tf.data.Dataset.range(3)
  dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
  dataset = dataset.repeat(2)
  # [1, 0, 2, 1, 2, 0]

  dataset = tf.data.Dataset.range(3)
  dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
  dataset = dataset.repeat(2)
  # [1, 0, 2, 1, 0, 2]
  ```

  In TF 2.0, `tf.data.Dataset` objects are Python iterables which makes it
  possible to also create epochs through Python iteration:

  ```python
  dataset = tf.data.Dataset.range(3)
  dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
  list(dataset.as_numpy_iterator())
  # [1, 0, 2]
  list(dataset.as_numpy_iterator())
  # [1, 2, 0]
  ```

  ```python
  dataset = tf.data.Dataset.range(3)
  dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
  list(dataset.as_numpy_iterator())
  # [1, 0, 2]
  list(dataset.as_numpy_iterator())
  # [1, 0, 2]
  ```

  Args:
    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      elements from this dataset from which the new dataset will sample.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the random
      seed that will be used to create the distribution. See
      `tf.random.set_seed` for behavior.
    reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
      that the dataset should be pseudorandomly reshuffled each time it is
      iterated over. (Defaults to `True`.)
    name: (Optional.) A name for the tf.data operation.

  Returns:
    Dataset: A `Dataset`.
  """
  return ShuffleDataset(
      self, buffer_size, seed, reshuffle_each_iteration, name=name)