tf.train.shuffle_batch函数解析

本文详细介绍tf.train.shuffle_batch函数的功能及用法,该函数通过随机打乱张量顺序创建批次,适用于TensorFlow数据处理流程。文章解释了各参数的意义,如capacity、min_after_dequeue等,并给出示例代码。

tf.train.shuffle_batch函数解析

觉得有用的话,欢迎一起讨论相互学习~


我的微博我的github我的B站

tf.train.shuffle_batch

  • (tensor_list, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, name=None)

  • Creates batches by randomly shuffling tensors. 通过随机打乱张量的顺序创建批次.

简单来说就是读取一个文件并且加载一个张量中的batch_size行

This function adds the following to the current Graph:
这个函数将以下内容加入到现有的图中.

  • A shuffling queue into which tensors from tensor_list are enqueued.
    一个由传入张量组成的随机乱序队列

  • A dequeue_many operation to create batches from the queue.
    从张量队列中取出张量的出队操作

  • A QueueRunner to QUEUE_RUNNER collection, to enqueue the tensors
    from tensor_list.
    一个队列运行器管理出队操作.
    If enqueue_many is False, tensor_list is assumed to represent a single example. An input tensor with shape [x, y, z] will be output as a tensor with shape [batch_size, x, y, z].

  • If enqueue_many is True, tensor_list is assumed to represent a batch of examples, where the first dimension is indexed by example, and all members of tensor_list should have the same size in the first dimension. If an input tensor has shape [*, x, y, z], the output will have shape [batch_size, x, y, z].

enqueue_many主要是设置tensor中的数据是否能重复,如果想要实现同一个样本多次出现可以将其设置为:“True”,如果只想要其出现一次,也就是保持数据的唯一性,这时候我们将其设置为默认值:“False”

  • The capacity argument controls the how long the prefetching is allowed to grow the queues. capacity控制了预抓取操作对于增加队列长度操作的长度.

  • For example:

# Creates batches of 32 images and 32 labels.
image_batch, label_batch = tf.train.shuffle_batch( [single_image, single_label], batch_size=32, num_threads=4,capacity=50000,min_after_dequeue=10000)

这段代码写的是从[single_image, single_label]利用4个线程读取32个数据作为一个batch

Args:
  • tensor_list: The list of tensors to enqueue.
    入队的张量列表
  • batch_size: The new batch size pulled from the queue.
    表示进行一次批处理的tensors数量.
  • capacity: An integer. The maximum number of elements in the queue.

容量:一个整数,队列中的最大的元素数.
这个参数一定要比min_after_dequeue参数的值大,并且决定了我们可以进行预处理操作元素的最大值.
推荐其值为:
c a p a c i t y = ( m i n _ a f t e r _ d e q u e u e + ( n u m _ t h r e a d s + a   s m a l l   s a f e t y   m a r g i n ∗ b a t c h s i z e ) capacity=(min\_after\_dequeue+(num\_threads+a\ small\ safety\ margin*batch_size) capacity=(min_after_dequeue+(num_threads+a small safety marginbatchsize)

  • min_after_dequeue: Minimum number elements in the queue after a
    dequeue(出列), used to ensure a level of mixing of elements.
  • 当一次出列操作完成后,队列中元素的最小数量,往往用于定义元素的混合级别.
  • 定义了随机取样的缓冲区大小,此参数越大表示更大级别的混合但是会导致启动更加缓慢,并且会占用更多的内存
  • num_threads: The number of threads enqueuing tensor_list.
  • 设置num_threads的值大于1,使用多个线程在tensor_list中读取文件,这样保证了同一时刻只在一个文件中进行读取操作(但是读取速度依然优于单线程),而不是之前的同时读取多个文件,这种方案的优点是:
  1. 避免了两个不同的线程从同一文件中读取用一个样本
  2. 避免了过多的磁盘操作
  • seed: Seed for the random shuffling within the queue.
    打乱tensor队列的随机数种子
  • enqueue_many: Whether each tensor in tensor_list is a single example.
    定义tensor_list中的tensor是否冗余.
  • shapes: (Optional) The shapes for each example. Defaults to the
    inferred shapes for tensor_list.
    用于改变读取tensor的形状,默认情况下和直接读取的tensor的形状一致.
  • name: (Optional) A name for the operations.
Returns:
  • A list of tensors with the same number and types as tensor_list.
    默认返回一个和读取tensor_list数据和类型一个tensor列表.
### TensorFlow `string_input_producer` 替代方案 自版本1.12起,TensorFlow逐步弃用了部分队列机制相关的函数,其中包括`tf.train.string_input_producer`。为了替代这一功能,在现代TensorFlow实践中推荐采用`tf.data.Dataset` API来处理输入管道[^1]。 #### 使用 `tf.data.Dataset.from_tensor_slices` 对于原本依赖于`string_input_producer`读取文件列表的应用场景,可以利用`Dataset.from_tensor_slices()`方法创建数据集实例: ```python import tensorflow as tf filenames = ['file1.txt', 'file2.txt'] dataset = tf.data.Dataset.from_tensor_slices(filenames) def _parse_function(filename): image_string = tf.io.read_file(filename) # 解析图像等内容... return parsed_content dataset = dataset.map(_parse_function).shuffle(buffer_size=10000).batch(32) ``` 此代码片段展示了如何通过指定文件名数组构建数据集,并定义了解码操作 `_parse_function` 来加载和预处理单个样本。接着应用 `.map()` 方法转换每个元素;`.shuffle()` 和 `.batch()` 则用于打乱顺序并分批传输给模型训练过程。 #### 构建更复杂的数据流图 当面对更加复杂的输入需求时,比如需要循环遍历整个数据集多次或设置epoch数量,则可以通过组合多个API实现所需行为: ```python # 假设 filenames 是一个字符串张量组成的列表 dataset = tf.data.Dataset.from_tensor_slices(filenames) # 设置重复次数 (无限次),并在每次迭代结束时重新洗牌 dataset = dataset.repeat().shuffle(len(filenames)) iterator = iter(dataset) next_element = iterator.get_next() ``` 上述例子中,`repeat()` 函数使得数据集可以在一轮结束后继续提供新批次直到显式停止条件达成;而 `shuffle()` 可确保每轮获取到不同排列的记录集合。
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值