Syntax
batch(batch_size, drop_remainder=False, num_parallel_calls=None, deterministic=None, name=None)
This function combines consecutive elements of this dataset into batches.
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
list(dataset.as_numpy_iterator())
# [array([0, 1, 2], dtype=int64), array([3, 4, 5], dtype=int64), array([6, 7], dtype=int64)]
The components of the returned dataset have an additional outer dimension of size batch_size (or N % batch_size for the final element, if batch_size does not divide the number of input elements N evenly and drop_remainder is False). If your program depends on batches having the same outer dimension, set the drop_remainder argument to True to prevent the smaller final batch from being produced. If your program requires data with a statically known shape, you should use drop_remainder=True; without drop_remainder=True, the output dataset's shape has an unknown leading dimension, because the final batch may be smaller.
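The remainder rule above can be sketched in plain Python. This is a toy illustration of the batching semantics only, not TensorFlow's actual implementation; batch_elements is a hypothetical helper:

```python
def batch_elements(elements, batch_size, drop_remainder=False):
    """Group consecutive elements into batches, mimicking Dataset.batch semantics."""
    batches = [elements[i:i + batch_size]
               for i in range(0, len(elements), batch_size)]
    # When N % batch_size != 0, the last batch is smaller; drop it if requested.
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches

print(batch_elements(list(range(8)), 3))
# [[0, 1, 2], [3, 4, 5], [6, 7]]
print(batch_elements(list(range(8)), 3, drop_remainder=True))
# [[0, 1, 2], [3, 4, 5]]
```

With drop_remainder=True every batch has exactly batch_size elements, which is why the output shape becomes statically known.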
Arguments
Argument | Meaning |
---|---|
batch_size | A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch. |
drop_remainder | (Optional.) A tf.bool scalar tf.Tensor, representing whether the last batch should be dropped if it has fewer than batch_size elements; defaults to False. |
num_parallel_calls | (Optional.) A tf.int64 scalar tf.Tensor, representing the number of batches to compute asynchronously in parallel. If not specified, batches are computed sequentially. If the value tf.data.AUTOTUNE is used, the number of parallel calls is set dynamically based on available resources. |
deterministic | (Optional.) When num_parallel_calls is specified, this boolean (True or False) controls the order in which the transformation produces elements. If set to False, the transformation may yield elements out of order, trading determinism for performance. If not specified, the tf.data.Options.deterministic option (True by default) controls the behavior. |
name | (Optional.) A name for the tf.data operation. |
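The interaction between num_parallel_calls and deterministic can be illustrated with a plain-Python sketch, using ThreadPoolExecutor as a stand-in for tf.data's internal parallelism. batch_parallel is a hypothetical helper, not a TensorFlow API:

```python
from concurrent.futures import ThreadPoolExecutor


def batch_parallel(elements, batch_size, num_parallel_calls=4):
    """Compute batches on worker threads while preserving input order,
    analogous to num_parallel_calls with deterministic=True."""
    starts = range(0, len(elements), batch_size)
    with ThreadPoolExecutor(max_workers=num_parallel_calls) as pool:
        # map() yields results in submission order, so the output stays
        # deterministic even though batches are computed concurrently.
        return list(pool.map(lambda i: elements[i:i + batch_size], starts))

print(batch_parallel(list(range(8)), 3))
# [[0, 1, 2], [3, 4, 5], [6, 7]]
```

Relaxing the ordering (deterministic=False in tf.data) lets a slow batch be overtaken by later ones, which can improve throughput at the cost of a reproducible element order.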
Returns
Return value | Meaning |
---|---|
Dataset | A tf.data.Dataset. |
Implementation
def batch(self,
batch_size,
drop_remainder=False,
num_parallel_calls=None,
deterministic=None,
name=None):
"""Combines consecutive elements of this dataset into batches.
>>> dataset = tf.data.Dataset.range(8)
>>> dataset = dataset.batch(3)
>>> list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]
>>> dataset = tf.data.Dataset.range(8)
>>> dataset = dataset.batch(3, drop_remainder=True)
>>> list(dataset.as_numpy_iterator())
[array([0, 1, 2]), array([3, 4, 5])]
The components of the resulting element will have an additional outer
dimension, which will be `batch_size` (or `N % batch_size` for the last
element if `batch_size` does not divide the number of input elements `N`
evenly and `drop_remainder` is `False`). If your program depends on the
batches having the same outer dimension, you should set the `drop_remainder`
argument to `True` to prevent the smaller batch from being produced.
Note: If your program requires data to have a statically known shape (e.g.,
when using XLA), you should use `drop_remainder=True`. Without
`drop_remainder=True` the shape of the output dataset will have an unknown
leading dimension due to the possibility of a smaller final batch.
Args:
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
consecutive elements of this dataset to combine in a single batch.
drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
whether the last batch should be dropped in the case it has fewer than
`batch_size` elements; the default behavior is not to drop the smaller
batch.
num_parallel_calls: (Optional.) A `tf.int64` scalar `tf.Tensor`,
representing the number of batches to compute asynchronously in
parallel.
If not specified, batches will be computed sequentially. If the value
`tf.data.AUTOTUNE` is used, then the number of parallel
calls is set dynamically based on available resources.
deterministic: (Optional.) When `num_parallel_calls` is specified, if this
boolean is specified (`True` or `False`), it controls the order in which
the transformation produces elements. If set to `False`, the
transformation is allowed to yield elements out of order to trade
determinism for performance. If not specified, the
`tf.data.Options.deterministic` option (`True` by default) controls the
behavior.
name: (Optional.) A name for the tf.data operation.
Returns:
Dataset: A `Dataset`.
"""
if num_parallel_calls is None or DEBUG_MODE:
if deterministic is not None and not DEBUG_MODE:
warnings.warn("The `deterministic` argument has no effect unless the "
"`num_parallel_calls` argument is specified.")
return BatchDataset(self, batch_size, drop_remainder, name=name)
else:
return ParallelBatchDataset(
self,
batch_size,
drop_remainder,
num_parallel_calls,
deterministic,
name=name)