TensorFlow: Reading Files in Batches with shuffle_batch

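The official TensorFlow documentation describes tf.train.shuffle_batch() as follows.

What the function does: Creates batches by randomly shuffling tensors.

Note, however, that it is a graph operation. Calling it only adds ops to the graph, and batches are actually produced only when those ops are evaluated inside sess.run().

This function adds the following to the current Graph:

  • A shuffling queue into which tensors from tensors are enqueued. In other words, a randomly ordered queue whose elements are exactly the tensors you pass in.
  • A dequeue_many operation to create batches from the queue. It pops batch_size elements from the queue at a time to form one batch.
  • A QueueRunner to QUEUE_RUNNER collection, to enqueue the tensors from tensors. The QueueRunner's thread is what pushes the input data into the queue.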
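The three pieces above can be modeled without TensorFlow. The sketch below is a single-threaded, pure-Python approximation. The function name shuffle_batches is hypothetical and is not a TensorFlow API.

In the sketch, the fill loop plays the role of the QueueRunner, and the buffer is the shuffling queue, bounded by capacity. Each yielded batch corresponds to one dequeue_many of batch_size random elements. While input remains, each dequeue leaves at least min_after_dequeue elements behind. As with shuffle_batch's default allow_smaller_final_batch=False, a final partial batch is dropped.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shuffle_batches(source: Iterable[T], batch_size: int, capacity: int,
                    min_after_dequeue: int, seed: int = 0) -> Iterator[List[T]]:
    """Hypothetical single-threaded model of shuffle_batch's queue (not a TF API)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if capacity < min_after_dequeue + batch_size:
        raise ValueError("capacity must be at least min_after_dequeue + batch_size")
    rng = random.Random(seed)
    buffer: List[T] = []   # plays the role of the shuffling queue
    it = iter(source)
    closed = False         # becomes True once the input is exhausted
    while True:
        # "QueueRunner": enqueue elements until the queue reaches capacity.
        while not closed and len(buffer) < capacity:
            try:
                buffer.append(next(it))
            except StopIteration:
                closed = True
        # min_after_dequeue is only enforced while the queue is still open.
        floor = 0 if closed else min_after_dequeue
        if len(buffer) < batch_size + floor:
            return  # fewer than batch_size leftovers: dropped
        # "dequeue_many": draw batch_size elements uniformly at random.
        batch = []
        for _ in range(batch_size):
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]  # swap-and-pop, O(1)
            batch.append(buffer.pop())
        yield batch


if __name__ == "__main__":
    for b in shuffle_batches(range(25), batch_size=4, capacity=10,
                             min_after_dequeue=5, seed=1):
        print(b)
```

With 25 inputs and batch_size=4, the usage above yields six batches of four, and one element is dropped. A larger min_after_dequeue mixes elements from further apart in the input, at the cost of more memory and a slower first batch.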

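Putting the data in a queue has several advantages. It decouples reading the training and test data from the computation that consumes it, and (in my own understanding) it makes distributed training easier to write. Be aware, though, that reading from the queue can easily block, for example when the queue is empty and no thread is refilling it.
In that case you should catch the timeout exception and use it to force the threads to stop. In TensorFlow 1.x, a blocked op raises tf.errors.DeadlineExceededError when the session was created with a tf.ConfigProto that sets operation_timeout_in_ms.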
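The catch-the-timeout-then-stop pattern does not depend on TensorFlow. The sketch below uses only Python's standard library: queue.Queue stands in for the TensorFlow queue, and a threading.Event stands in for the stop flag. The function name read_until_idle is hypothetical.

```python
import queue
import threading
from typing import List


def read_until_idle(q: "queue.Queue[int]", stop: threading.Event,
                    timeout_s: float) -> List[int]:
    """Read from q until it stays empty for timeout_s seconds, then request a stop."""
    items: List[int] = []
    while not stop.is_set():
        try:
            # Blocks for at most timeout_s seconds instead of forever.
            items.append(q.get(timeout=timeout_s))
        except queue.Empty:
            # Timed out: nothing is feeding the queue, so signal every thread to stop.
            stop.set()
    return items


if __name__ == "__main__":
    q: "queue.Queue[int]" = queue.Queue()
    for i in range(3):
        q.put(i)
    stop = threading.Event()
    print(read_until_idle(q, stop, timeout_s=0.1))  # [0, 1, 2]
```

The key design point is that the read has a bounded wait. Without one, a reader blocked on an empty queue never gets a chance to check the stop flag.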

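Coordinator
The Coordinator class helps multiple threads work together and stop together. Its main methods are:

should_stop(): returns True if the threads should stop.

request_stop(<exception>): requests that all coordinated threads stop. An exception passed here is re-raised later by join().

join(<list of threads>): waits for the specified threads to terminate.

First create a Coordinator object, then create threads that use it. These threads typically loop until should_stop() returns True. Any thread can decide that the computation should stop: it just calls request_stop(). From then on, should_stop() returns True in every other thread, and they all stop.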
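The protocol just described can be sketched in plain Python. MiniCoordinator below is a hypothetical, simplified stand-in for tf.train.Coordinator, built on threading.Event. It is not the TensorFlow class and omits options such as the grace period.

Workers loop while should_stop() is False, and any worker may call request_stop(). join() waits for the threads and then re-raises the first exception reported through request_stop().

```python
import threading
from typing import List, Optional


class MiniCoordinator:
    """Hypothetical simplified stand-in for tf.train.Coordinator."""

    def __init__(self) -> None:
        self._stop = threading.Event()
        self._lock = threading.Lock()
        self._exc: Optional[BaseException] = None

    def should_stop(self) -> bool:
        return self._stop.is_set()

    def request_stop(self, ex: Optional[BaseException] = None) -> None:
        with self._lock:
            if ex is not None and self._exc is None:
                self._exc = ex  # remember only the first reported error
            self._stop.set()

    def join(self, threads: List[threading.Thread]) -> None:
        for t in threads:
            t.join()
        if self._exc is not None:
            raise self._exc


def worker(coord: MiniCoordinator, counts: List[int], idx: int, limit: int) -> None:
    """Loop until a stop is requested; worker 0 decides when to stop."""
    try:
        while not coord.should_stop():
            counts[idx] += 1
            if idx == 0 and counts[idx] >= limit:
                coord.request_stop()  # any thread may stop everyone
    except Exception as e:  # report errors so join() can re-raise them
        coord.request_stop(e)


if __name__ == "__main__":
    coord = MiniCoordinator()
    counts = [0, 0, 0]
    threads = [threading.Thread(target=worker, args=(coord, counts, i, 100))
               for i in range(3)]
    for t in threads:
        t.start()
    coord.join(threads)
    print(counts[0])  # 100
```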


The complete example below reads the UCI Adult data files (adult.data and adult.test) in shuffled batches of 10:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Uses the TensorFlow 1.x queue-based input API.
import tensorflow as tf

# Column names of the Adult data set: 15 comma-separated fields per line.
COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
           "marital_status", "occupation", "relationship", "race", "gender",
           "capital_gain", "capital_loss", "hours_per_week", "native_country",
           "income_bracket"]

def read_file_format(filename_queue):
    """
    Read file data; each execution reads one line.
    :param filename_queue: queue of file names from string_input_producer
    :return: (features, label) tensors for a single example
    """
    # Define the reader
    reader = tf.TextLineReader()
    # Each execution of read reads a single line from the file.
    key, record_string = reader.read(filename_queue)

    # Define the decoder
    # record_defaults: value used when a field is empty. It needs one entry per
    # column, and its length must equal the number of fields in each line.
    record_defaults = [[''] for _ in COLUMNS]
    cols = tf.decode_csv(record_string, record_defaults=record_defaults)
    # Stack the 15 string columns into one tensor so shuffle_batch can batch it
    # (a list of (name, tensor) pairs from zip() cannot be enqueued).
    features = tf.stack(cols)
    # Placeholder label: every example gets label 1.
    label = tf.constant(1)
    return features, label


def input_pipeline(filenames, batch_size, num_epochs=None):
    """
    Input pipeline that reads the data through queues.
    :param filenames: list of CSV file paths
    :param batch_size: number of examples per batch
    :param num_epochs: passes over the files; None cycles forever
    :return: (example_batch, label_batch)
    """
    # creates a FIFO queue for holding the filenames until the reader needs them
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epochs, shuffle=True)

    single_example, single_label = read_file_format(filename_queue)
    # min_after_dequeue defines how big a buffer we will randomly sample
    #   from -- bigger means better shuffling but slower start up and more
    #   memory used.
    # capacity must be larger than min_after_dequeue and the amount larger
    #   determines the maximum we will prefetch.  Recommendation:
    #   min_after_dequeue + (num_threads + a small safety margin) * batch_size

    # Creates batches of batch_size examples and batch_size labels
    # by randomly shuffling tensors.
    min_after_dequeue = 10000
    capacity = min_after_dequeue + 3 * batch_size
    example_batch, label_batch = tf.train.shuffle_batch(
        [single_example, single_label], batch_size=batch_size, capacity=capacity,
        min_after_dequeue=min_after_dequeue)

    return example_batch, label_batch


train_file_name = "./data/adult.data"
test_file_name = "./data/adult.test"

example_batch, label_batch = input_pipeline([train_file_name, test_file_name], batch_size=10)

with tf.Session() as sess:
    # Required when num_epochs is set: it creates a local counter variable.
    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
    # Start populating the filename queue.
    # The Coordinator helps the threads work together and stop together.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    train_steps = 10
    try:
        while not coord.should_stop():  # True once a stop has been requested
            example, label = sess.run([example_batch, label_batch])
            print(example)

            train_steps -= 1
            print(train_steps)
            if train_steps <= 0:
                coord.request_stop()    # ask all threads to stop

    except tf.errors.OutOfRangeError:
        print('Done training -- epoch limit reached')
    finally:
        # When done, ask the threads to stop.
        coord.request_stop()
        # And wait for them to actually do it.
        coord.join(threads)
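The capacity rule of thumb quoted in the code comments, min_after_dequeue + (num_threads + a small safety margin) * batch_size, is plain arithmetic. The helper below is a sketch, and its name and default margin of 2 are assumptions. With one thread (shuffle_batch's default num_threads) and a margin of 2, it reproduces the example's value of 10000 + 3 * 10.

```python
def recommended_capacity(min_after_dequeue: int, batch_size: int,
                         num_threads: int = 1, safety_margin: int = 2) -> int:
    """Rule of thumb: min_after_dequeue + (num_threads + safety_margin) * batch_size."""
    if min_after_dequeue < 0 or batch_size <= 0 or num_threads <= 0 or safety_margin < 0:
        raise ValueError("invalid queue parameters")
    return min_after_dequeue + (num_threads + safety_margin) * batch_size


if __name__ == "__main__":
    # Same value as `min_after_dequeue + 3 * batch_size` in the example above.
    print(recommended_capacity(10000, 10))  # 10030
```

If you raise num_threads in shuffle_batch, the capacity should grow by one batch per extra thread. That way each enqueuing thread has room to prefetch.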
