The tf.train.shuffle_batch function in TensorFlow

The documentation for tf.train.shuffle_batch reads as follows:

Help on function shuffle_batch in module tensorflow.python.training.input:

shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, enqueue_many=False, shapes=None, allow_smaller_final_batch=False, shared_name=None, name=None)
    Creates batches by randomly shuffling tensors. (deprecated)
    
    Warning: THIS FUNCTION IS DEPRECATED. It will be removed in a future version.
    Instructions for updating:
    Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.shuffle(min_after_dequeue).batch(batch_size)`.
    
    This function adds the following to the current `Graph`:
    
    * A shuffling queue into which tensors from `tensors` are enqueued.
    * A `dequeue_many` operation to create batches from the queue.
    * A `QueueRunner` to `QUEUE_RUNNER` collection, to enqueue the tensors
      from `tensors`.
    
    If `enqueue_many` is `False`, `tensors` is assumed to represent a
    single example.  An input tensor with shape `[x, y, z]` will be output
    as a tensor with shape `[batch_size, x, y, z]`.
    
    If `enqueue_many` is `True`, `tensors` is assumed to represent a
    batch of examples, where the first dimension is indexed by example,
    and all members of `tensors` should have the same size in the
    first dimension.  If an input tensor has shape `[*, x, y, z]`, the
    output will have shape `[batch_size, x, y, z]`.
    
    The `capacity` argument controls how long the prefetching is allowed to
    grow the queues.
    
    The returned operation is a dequeue operation and will throw
    `tf.errors.OutOfRangeError` if the input queue is exhausted. If this
    operation is feeding another input queue, its queue runner will catch
    this exception, however, if this operation is used in your main thread
    you are responsible for catching this yourself.
    
    For example:
    
    ```python
    # Creates batches of 32 images and 32 labels.
    image_batch, label_batch = tf.compat.v1.train.shuffle_batch(
          [single_image, single_label],
          batch_size=32,
          num_threads=4,
          capacity=50000,
          min_after_dequeue=10000)
    ```
    
    *N.B.:* You must ensure that either (i) the `shapes` argument is
    passed, or (ii) all of the tensors in `tensors` have fully-defined
    shapes. `ValueError` will be raised if neither of these conditions
    holds.
    
    If `allow_smaller_final_batch` is `True`, a smaller batch value than
    `batch_size` is returned when the queue is closed and there are not enough
    elements to fill the batch, otherwise the pending elements are discarded.
    In addition, all output tensors' static shapes, as accessed via the
    `shape` property will have a first `Dimension` value of `None`, and
    operations that depend on fixed batch_size would fail.
    
    Args:
      tensors: The list or dictionary of tensors to enqueue.
      batch_size: The new batch size pulled from the queue.
      capacity: An integer. The maximum number of elements in the queue.
      min_after_dequeue: Minimum number of elements in the queue after a
        dequeue, used to ensure a level of mixing of elements.
      num_threads: The number of threads enqueuing `tensor_list`.
      seed: Seed for the random shuffling within the queue.
      enqueue_many: Whether each tensor in `tensor_list` is a single example.
      shapes: (Optional) The shapes for each example.  Defaults to the
        inferred shapes for `tensor_list`.
      allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
        batch to be smaller if there are insufficient items left in the queue.
      shared_name: (Optional) If set, this queue will be shared under the given
        name across multiple sessions.
      name: (Optional) A name for the operations.
    
    Returns:
      A list or dictionary of tensors with the types as `tensors`.
    
    Raises:
      ValueError: If the `shapes` are not specified, and cannot be
        inferred from the elements of `tensors`.

In short, the function takes tensors, enqueues them into a shuffling queue, and dequeues randomized batches. The parameters are:
tensors: the list or dictionary of tensors to enqueue.
batch_size: the batch size pulled from the queue.
capacity: the maximum number of elements in the queue.
min_after_dequeue: the minimum number of elements left in the queue after a dequeue, used to ensure a level of mixing of elements.
num_threads: the number of enqueuing threads.
seed: seed for the random shuffling within the queue.
enqueue_many: whether each tensor in tensors represents a single example.
shapes: the shapes for each example. (optional)
allow_smaller_final_batch: if True, allow the final batch to be smaller when there are not enough items left in the queue. (optional)
shared_name: if set, the queue will be shared under the given name across multiple sessions. (optional)
name: a name for the operations. (optional)

The three parameters that matter most are batch_size, capacity, and min_after_dequeue; together they determine how data is output and how thoroughly it is shuffled. Records come out of the tfrecord file through a queue, and capacity sets the maximum number of elements that queue may hold. Note that the queue is first filled in tfrecord order; only after enough elements have been buffered are they shuffled and dequeued, batch_size elements at a time, for the model to train on. min_after_dequeue is the minimum number of elements that must remain in the queue after a dequeue, and it is what guarantees the mixing: whenever dequeuing would drop the queue below min_after_dequeue, fresh records are first loaded from the tfrecord file. Consequently, the closer min_after_dequeue is to capacity, the more new records get mixed in after every dequeue, and the better the shuffle, at the cost of memory and startup time.
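As a sizing aid, the old TensorFlow reading-data guide suggests deriving capacity from the other two values so the enqueuing threads always have headroom. A minimal sketch of that rule of thumb (the safety margin of 3 is just the commonly quoted value, not something fixed by the API):

batch_size = 25
num_threads = 4
min_after_dequeue = 10000
# Keep min_after_dequeue elements buffered for mixing, plus room for the
# batches that the enqueuing threads may be filling concurrently.
capacity = min_after_dequeue + (num_threads + 3) * batch_size
print(capacity)  # 10175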

Usage is as follows:

#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np

images = np.random.random([5,2])
label = np.asarray(range(0, 5))
images = tf.cast(images, tf.float32)
label = tf.cast(label, tf.int32)
input_queue = tf.train.slice_input_producer([images, label], shuffle=False)
# Shuffle the queued data and read it back out in batches
image_batch, label_batch = tf.train.shuffle_batch(input_queue, batch_size=10, num_threads=1, capacity=64, min_after_dequeue=1)

with tf.Session() as sess:
    # Thread coordinator
    coord = tf.train.Coordinator()
    # Start the queue runners collected in the graph
    threads = tf.train.start_queue_runners(sess, coord)
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    for j in range(5):
        # print(image_batch_v.shape, label_batch_v[j])
        print(image_batch_v[j])
        print(label_batch_v[j])
    # Ask the threads to stop
    coord.request_stop()
    # Wait for the threads to terminate
    coord.join(threads)
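Since shuffle_batch is deprecated, the same pipeline can also be written with tf.data, the replacement named in the deprecation warning above. A minimal sketch under TF 1.x, where shuffle's buffer_size plays the role of min_after_dequeue:

import tensorflow as tf
import numpy as np

images = np.random.random([5, 2]).astype(np.float32)
labels = np.arange(5, dtype=np.int32)

# repeat() mimics the endlessly refilled queue, shuffle(buffer_size) plays
# the role of min_after_dequeue, and batch() replaces the dequeue_many op.
dataset = (tf.data.Dataset.from_tensor_slices((images, labels))
           .repeat()
           .shuffle(buffer_size=5)
           .batch(10))
image_batch, label_batch = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
    print(image_batch_v.shape, label_batch_v)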

Another, complete example program follows:

import os
import tensorflow as tf 
from PIL import Image
from nets import nets_factory
import numpy as np
# Number of distinct characters
CHAR_SET_LEN = 10
# Image height
IMAGE_HEIGHT = 60
# Image width
IMAGE_WIDTH = 160
# Batch size
BATCH_SIZE = 25
# Path to the tfrecord file
TFRECORD_FILE = "D:/ddd-ss/tf_awei/Verification code/captcha/train.tfrecords"

# placeholder
x = tf.placeholder(tf.float32, [None, 224, 224])  
y0 = tf.placeholder(tf.float32, [None]) 
y1 = tf.placeholder(tf.float32, [None]) 
y2 = tf.placeholder(tf.float32, [None]) 
y3 = tf.placeholder(tf.float32, [None])

# Learning rate
lr = tf.Variable(0.003, dtype=tf.float32)

# Read and decode data from the tfrecord file
def read_and_decode(filename):
    # Build a filename queue from the file name
    filename_queue = tf.train.string_input_producer([filename])  # input: a 1-D list of strings; output: a string queue
    reader = tf.TFRecordReader()  # a reader that outputs the records from a TFRecords file
    # Returns the record key and the serialized record
    bbb, serialized_example = reader.read(filename_queue)     # bbb is the key, serialized_example the value (think of them as a dict's key and value)
    features = tf.parse_single_example(serialized_example,    # parse a single Example proto from the serialized string
                                       features={
                                           'image' : tf.FixedLenFeature([], tf.string),
                                           'label0': tf.FixedLenFeature([], tf.int64),
                                           'label1': tf.FixedLenFeature([], tf.int64),
                                           'label2': tf.FixedLenFeature([], tf.int64),
                                           'label3': tf.FixedLenFeature([], tf.int64),
                                       })      # returns a dict mapping feature keys to `Tensor` and `SparseTensor` values
    # Decode the image data
    image = tf.decode_raw(features['image'], tf.uint8)   # convert the raw byte string into a tensor
    # tf.train.shuffle_batch requires fully defined shapes
    image = tf.reshape(image, [224, 224])
    # Image preprocessing: cast from tf.uint8 to tf.float32 and scale to [-1, 1]
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)
    # Decode the labels
    label0 = tf.cast(features['label0'], tf.int32)
    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    return image, label0, label1, label2, label3       # all returned values are tensors


# Get the image data and labels
image, label0, label1, label2, label3 = read_and_decode(TFRECORD_FILE)

# shuffle_batch randomly shuffles the examples into batches
image_batch, label_batch0, label_batch1, label_batch2, label_batch3 = tf.train.shuffle_batch(
        [image, label0, label1, label2, label3], batch_size = BATCH_SIZE,
        capacity = 50000, min_after_dequeue=10000, num_threads=1)

# Define the network architecture
train_network_fn = nets_factory.get_network_fn(
    'alexnet_v2',
    num_classes=CHAR_SET_LEN,
    weight_decay=0.0005,
    is_training=True)
 
    
with tf.Session() as sess:
    # inputs: a tensor of size [batch_size, height, width, channels]
    X = tf.reshape(x, [BATCH_SIZE, 224, 224, 1])
    # Feed the data through the network to get the logits
    logits0,logits1,logits2,logits3,end_points = train_network_fn(X)
    
    # Convert the labels to one-hot form
    one_hot_labels0 = tf.one_hot(indices=tf.cast(y0, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels1 = tf.one_hot(indices=tf.cast(y1, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels2 = tf.one_hot(indices=tf.cast(y2, tf.int32), depth=CHAR_SET_LEN)
    one_hot_labels3 = tf.one_hot(indices=tf.cast(y3, tf.int32), depth=CHAR_SET_LEN)
    
    # Compute the per-digit losses
    loss0 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits0,labels=one_hot_labels0)) 
    loss1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits1,labels=one_hot_labels1)) 
    loss2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits2,labels=one_hot_labels2)) 
    loss3 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits3,labels=one_hot_labels3)) 
    # Compute the total loss
    total_loss = (loss0+loss1+loss2+loss3)/4.0
    # Minimize total_loss
    optimizer = tf.train.AdamOptimizer(learning_rate=lr).minimize(total_loss) 
    
    # Compute the accuracies
    correct_prediction0 = tf.equal(tf.argmax(one_hot_labels0,1),tf.argmax(logits0,1))
    accuracy0 = tf.reduce_mean(tf.cast(correct_prediction0,tf.float32))
    
    correct_prediction1 = tf.equal(tf.argmax(one_hot_labels1,1),tf.argmax(logits1,1))
    accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1,tf.float32))
    
    correct_prediction2 = tf.equal(tf.argmax(one_hot_labels2,1),tf.argmax(logits2,1))
    accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2,tf.float32))
    
    correct_prediction3 = tf.equal(tf.argmax(one_hot_labels3,1),tf.argmax(logits3,1))
    accuracy3 = tf.reduce_mean(tf.cast(correct_prediction3,tf.float32)) 
    
    # Saver for checkpointing the model
    saver = tf.train.Saver()
    # Initialize the variables
    sess.run(tf.global_variables_initializer())
    
    # Create a coordinator to manage the threads
    coord = tf.train.Coordinator()
    # Start the QueueRunners; the filename queue is now being filled
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(6001):
        # Fetch one batch of data and labels
        b_image, b_label0, b_label1 ,b_label2 ,b_label3 = sess.run([image_batch, label_batch0, label_batch1, label_batch2, label_batch3])
        # Run one optimization step
        sess.run(optimizer, feed_dict={x: b_image, y0:b_label0, y1: b_label1, y2: b_label2, y3: b_label3})  

        # Every 20 iterations, compute the loss and accuracies
        if i % 20 == 0:
            # Every 2000 iterations, decay the learning rate
            if i % 2000 == 0:
                sess.run(tf.assign(lr, lr/3))
            acc0,acc1,acc2,acc3,loss_ = sess.run([accuracy0,accuracy1,accuracy2,accuracy3,total_loss],feed_dict={x: b_image,
                                                                                                                y0: b_label0,
                                                                                                                y1: b_label1,
                                                                                                                y2: b_label2,
                                                                                                                y3: b_label3})      
            learning_rate = sess.run(lr)
            print ("Iter:%d  Loss:%.3f  Accuracy:%.2f,%.2f,%.2f,%.2f  Learning_rate:%.4f" % (i,loss_,acc0,acc1,acc2,acc3,learning_rate))
             
            # Save the model
            # if acc0 > 0.90 and acc1 > 0.90 and acc2 > 0.90 and acc3 > 0.90: 
            if i==6000:
                saver.save(sess, "./captcha/models/crack_captcha.model", global_step=i)  
                break 
                
    # Ask the other threads to stop
    coord.request_stop()
    # join() returns only after all other threads have shut down
    coord.join(threads)
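The feature keys passed to parse_single_example only resolve if the tfrecords file was written with the same names. For reference, a minimal, hypothetical writer-side sketch (the zero image and the label values are placeholders, not the actual captcha data):

import numpy as np
import tensorflow as tf

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

# A placeholder grayscale image and its four digit labels
image = np.zeros([224, 224], dtype=np.uint8)
labels = [1, 2, 3, 4]

example = tf.train.Example(features=tf.train.Features(feature={
    'image' : _bytes_feature(image.tobytes()),   # raw bytes, matching tf.decode_raw above
    'label0': _int64_feature(labels[0]),
    'label1': _int64_feature(labels[1]),
    'label2': _int64_feature(labels[2]),
    'label3': _int64_feature(labels[3]),
}))

# tf.python_io.TFRecordWriter is the TF 1.x writer class
with tf.python_io.TFRecordWriter('train.tfrecords') as writer:
    writer.write(example.SerializeToString())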