在TensorFlow中使用pipeline加载数据

正文共2028个字,6张图,预计阅读时间6分钟。


前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据。数据流如下图所示:



首先,A、B、C三个文件通过RandomShuffle进程被随机加载到FilenameQueue里,然后Reader1和Reader2进程同FilenameQueue里取文件名读取文件,读取的内容再被放到ExampleQueue里。最后,计算进程会从ExampleQueue里取数据。各个进程独立操作,互不影响,这样可以加快程序速度。


我们简单地生成3个样本文件。


#生成三个样本文件,每个文件包含5列,假设前4列为特征,最后1列为标签

data = np.zeros([20,5]) np.savetxt('file0.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file1.csv', data, fmt='%d', delimiter=',') data += 1np.savetxt('file2.csv', data, fmt='%d', delimiter=',')


然后,创建pipeline数据流。


#定义FilenameQueuefilename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)])

#定义ExampleQueue

example_queue = tf.RandomShuffleQueue(    capacity=1000,    min_after_dequeue=0,    dtypes=[tf.int32,tf.int32],    shapes=[[4],[1]] )

#读取CSV文件,每次读一行

reader = tf.TextLineReader() key, value = reader.read(filename_queue)

#对一行数据进行解码

record_defaults = [[1], [1], [1], [1], [1]] col1, col2, col3, col4, col5 = tf.decode_csv(    value, record_defaults=record_defaults) features = tf.stack([col1, col2, col3, col4])

#将特征和标签push进ExampleQueue

enq_op = example_queue.enqueue([features, [col5]])

#使用QueueRunner创建两个进程加载数据到ExampleQueue

qr = tf.train.QueueRunner(example_queue, [enq_op]*2)

#使用此方法方便后面tf.train.start_queue_runner统一开始进程

tf.train.add_queue_runner(qr) xs = example_queue.dequeue()

with tf.Session() as sess:    coord = tf.train.Coordinator()

#开始所有进程    threads = tf.train.start_queue_runners(coord=coord)    

for i in range(200):        x = sess.run(xs)        print(x)    coord.request_stop()    coord.join(threads)



以上我们采用for循环step_num次来控制训练迭代次数。我们也可以通过tf.train.string_input_producer的num_epochs参数来设置FilenameQueue循环次数来控制训练,当达到num_epochs时,TensorFlow会抛出OutOfRangeError异常,通过捕获该异常,停止训练。


filename_queue = tf.train.string_input_producer(["file%d.csv"%i for i in range(3)], num_epochs=6) ...

with tf.Session() as sess:    sess.run(tf.initialize_local_variables()) #必须加上这句话,否则报错!    coord = tf.train.Coordinator()

#开始所有进程

   threads = tf.train.start_queue_runners(coord=coord)    

try:        

while not coord.should_stop():            x = sess.run(xs)            print(x)    

except tf.errors.OutOfRangeError:        print('Done training -- epch limit reached')    

finally:        coord.request_stop()


捕获到异常时,请求结束所有进程。



原文: 在TensorFlow中使用pipeline加载数据(https://goo.gl/jbVPjM)


原文链接:https://www.jianshu.com/p/12b52e54a63c


查阅更为简洁方便的分类文章以及最新的课程、产品信息,请移步至全新呈现的“LeadAI学院官网”:

www.leadai.org


请关注人工智能LeadAI公众号,查看更多专业文章

大家都在看

LSTM模型在问答系统中的应用

基于TensorFlow的神经网络解决用户流失概览问题

最全常见算法工程师面试题目整理(一)

最全常见算法工程师面试题目整理(二)

TensorFlow从1到2 | 第三章 深度学习革命的开端:卷积神经网络

装饰器 | Python高级编程

今天不如来复习下Python基础


  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值