多线程输入数据处理框架

经典输入数据处理流程图:指定原始数据的文件列表--->创建文件列表队列-->从文件中读取数据-->数据预处理-->整理成batch作为神经网络的输入。

队列也是TF多线程输入数据处理框架的基础。队列和变量类似,都是计算图上有状态的节点。修改队列状态的操作主要有:Enqueue、EnqueueMany、Dequeue

TF中提供了FIFOQueue(先进先出)和RandomShuffleQueue(会将队列中的元素随机打乱)两种队列。

附上一个简单例子:

import tensorflow as tf
#需给出min_after_dequeue的值
q = tf.RandomShuffleQueue(capacity=4, dtypes="int32", min_after_dequeue=2)
#初始化,关于shape不太清楚
init = q.enqueue_many(([0, 10, 6, 8], ), name="q")
#出队
x = q.dequeue()
y = x + 1
#入队
q_inc = q.enqueue([y])

with tf.Session() as sess:
    sess.run(init)
    for i in range(10):
        v, _ =sess.run([x, q_inc])
        print(v)
TF提供了tf.Coordinator(协同多个线程一起停止)和tf.QueueRunner两个类来完成多线程协同的功能。给出一个使用tf.Coordinator的例子:
import tensorflow as tf
import numpy as np
import threading
import time

def MyLoop(coord, worker_id):
    #使用tf.Coordinator类提供的协同工具来判断当前线程是否需要停止
    while not coord.should_stop():
        if np.random.rand() < 0.1:
            print("Stopping from id:%d\n" % worker_id)
            coord.request_stop()
        else:
            print("Working on id: %d\n"% worker_id)

    time.sleep(1)
    #声明一个tf.train.Coordinator类来协同多个线程
coord = tf.train.Coordinator()
    #声明创建5个线程Thread(group=None, target=None, name=None, args=(), kwargs={})
    #group: 线程组,目前还没有实现,库引用中提示必须是None;
    #target: 要执行的方法;
    #name: 线程名;
    #args/kwargs: 要传入方法的参数。
threads = [threading.Thread(target=MyLoop, args=(coord, i,)) for i in range(5)]
#启动线程
for t in threads:
    t.start()
#等待所有线程退出
coord.join(threads)
tf.QueueRunner主要用于启动多个线程来操作同一个队列。tf.Coordinator和tf.QueueRunner联合使用的例子:

import tensorflow as tf
import numpy as np
queue = tf.FIFOQueue(capacity=100, dtypes="float")
enqueue_op = queue.enqueue([tf.random_normal([1])])
#init_op = queue.enqueue_many(([np.random.rand(100)]), )
#启动5个线程, 每个线程中运行enqueue_op操作
qr = tf.train.QueueRunner(queue, [enqueue_op]*5)
tf.train.add_queue_runner(qr)
out_tensor = queue.dequeue()

with tf.Session() as sess:
    #sess.run(init_op)   不初始化也能出结果
    print(queue)
    coord = tf.train.Coordinator()
    #使用tf.train.QueueRunner时需要明确调用tf.train.start_queue_runners来启动所有线程
    #tf.train.start_queue_runners函数会默认启动tf.GraphKeys.QUEUE_RUNNER集合中的所有QueueRunner
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(3):
        print(sess.run(out_tensor))
        coord.request_stop()
        coord.join(threads)
输入文件队列:

tf.train.string_input_producer会把输入队列中的文件均匀地分给不同的线程,当一个输入队列中的所有文件被处理完后,它会将初始化时提供的文件列表中的文件全部重新加入队列。可以通过设置num_epochs限制加载初始文件列表的最大轮数。

import tensorflow as tf
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

num_shards = 10
instances_per_shard = 5
for i in range(num_shards):
    #将数据分为多个文件时,可以将不同文件以类似0000n-of-0000m的后缀区分。
    #其中m表示数据总共被存在了m个文件中,n表示当前文件标号。
    filename = ('/home/cvx/Downloads/Data/data.tfrecords-%.5d-of-%.5d' % (i, num_shards))
    writer = tf.python_io.TFRecordWriter(filename)
    for j in range(instances_per_shard):
        example = tf.train.Example(features=tf.train.Features(feature={'i': _int64_feature(i),
                                                                       'j': _int64_feature(j)}))
        writer.write(example.SerializeToString())
    writer.close()

#使用tf.train.match_filenames_once函数获取文件列表
files = tf.train.match_filenames_once("/home/cvx/Downloads/Data/data.tfrecords-*")
#init_op = (tf.global_variables_initializer(), tf.local_variables_initializer()) -----需要初始化局部变量,否则会报错。
#with tf.Session() as sess:
    #sess.run(init_op)
    #print(sess.run(files)) ------此段代码可以查看files里的内容,下面也有类似的实现

#创建了输入队列,输入队列中文件列表为files
filename_queue = tf.train.string_input_producer(files, shuffle=False)
#解析样本
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
    'i': tf.FixedLenFeature([], tf.int64),
    'j': tf.FixedLenFeature([], tf.int64),
})
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    tf.local_variables_initializer().run()
    #print(sess.run(files))
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(6):
        print(sess.run([features['i'], features['j']]))
    coord.request_stop()
    coord.join(threads)





  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值