TensorFlow学习--tf.train.batch与tf.train.shuffle_batch

tf.train.batch与tf.train.shuffle_batch的作用都是从队列中读取数据.

tf.train.batch

tf.train.batch() 按顺序读取队列中的数据
队列中的数据始终是一个有序的队列.队头一直按顺序补充,队尾一直按顺序出队.
参数:

  • tensors:排列的张量或词典.
  • batch_size:从队列中提取新的批量大小.
  • num_threads:线程数量.若批次是不确定 num_threads > 1.
  • capacity:队列中元素的最大数量.
  • enqueue_many:tensors中的张量是否都是一个例子.
  • shapes:每个示例的形状.(可选项)
  • dynamic_pad:在输入形状中允许可变尺寸.
  • allow_smaller_final_batch:为True时,若队列中没有足够的项目,则允许最终批次更小.(可选项)
  • shared_name:如果设置,则队列将在多个会话中以给定名称共享.(可选项)
  • name:操作的名称.(可选项)

若enqueue_many为False,则认为tensors代表一个示例.输入张量形状为[x, y, z]时,则输出张量形状为[batch_size, x, y, z].
若enqueue_many为True,则认为tensors代表一批示例,其中第一个维度为示例的索引,并且所有成员tensors在第一维中应具有相同大小.若输入张量形状为[*, x, y, z],则输出张量的形状为[batch_size, x, y, z].

tf.train.batch()示例
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np

images = np.random.random([5,2])
label = np.asarray(range(0, 5))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
# 切片
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# 按顺序读取队列中的数据
image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64)

with tf.Session() as sess:
    # 线程的协调器
    coord = tf.train.Coordinator()
    # 开始在图表中收集队列运行器
    threads = tf.train.start_queue_runners(sess, coord)
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    for j in range(5):
        print(image_batch_v[j]),
        print(label_batch_v[j])
    # 请求线程结束
    coord.request_stop()
    # 等待线程终止
    coord.join(threads)

按顺序读取队列中的数据,输出:

[ 0.05013787  0.53446019] 0
[ 0.91189879  0.69153142] 1
[ 0.39966023  0.86109054] 2
[ 0.85078746  0.05766034] 3
[ 0.71261722  0.60514599] 4

tf.train.shuffle_batch

tf.train.shuffle_batch() 将队列中数据打乱后再读取出来.
函数是先将队列中数据打乱,然后再从队列里读取出来,因此队列中剩下的数据也是乱序的.

  • tensors:排列的张量或词典.
  • batch_size:从队列中提取新的批量大小.
  • capacity:队列中元素的最大数量.
  • min_after_dequeue:出队后队列中元素的最小数量,用于确保元素的混合级别.
  • num_threads:线程数量.
  • seed:队列内随机乱序的种子值.
  • enqueue_many:tensors中的张量是否都是一个例子.
  • shapes:每个示例的形状.(可选项)
  • allow_smaller_final_batch:为True时,若队列中没有足够的项目,则允许最终批次更小.(可选项)
  • shared_name:如果设置,则队列将在多个会话中以给定名称共享.(可选项)
  • name:操作的名称.(可选项)

其他与tf.train.batch()类似.

tf.train.shuffle_batch示例
#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np

images = np.random.random([5,2])
label = np.asarray(range(0, 5))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# 将队列中数据打乱后再读取出来
image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=10, num_threads=1, capacity=64, min_after_dequeue=1)

with tf.Session() as sess:
    # 线程的协调器
    coord = tf.train.Coordinator()
    # 开始在图表中收集队列运行器
    threads = tf.train.start_queue_runners(sess, coord)
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    for j in range(5):
        # print(image_batch_v.shape, label_batch_v[j])
        print(image_batch_v[j]),
        print(label_batch_v[j])
    # 请求线程结束
    coord.request_stop()
    # 等待线程终止
    coord.join(threads)

将队列中数据打乱后再读取出来,输出:

[ 0.08383977  0.75228119] 1
[ 0.03610427  0.53876138] 0
[ 0.33962703  0.47629601] 3
[ 0.21824744  0.84182823] 4
[ 0.8376292   0.52254623] 2
  • 8
    点赞
  • 53
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值