tf读取数据的几种方式

本文介绍了TensorFlow中读取数据的三种方式:直接读取预加载数据、使用feed_dict和从文件中读取,重点讨论了从文件中读取的两种方法——队列管理器和Dataset API。通过队列管理器,数据在多线程中高效组织,最后一个epoch不满batch的数据会被丢弃。而Dataset API允许更灵活的数据处理,包括对不同长度数据的padding,并能实现可复用的迭代器,适用于train和valid的交替训练。两种方法都避免了手动管理线程队列,提供了高效的数据输入解决方案。
摘要由CSDN通过智能技术生成

1.最简单的方式

import tensorflow as tf

a = tf.zeros([2,3])
b = tf.ones([2,3])
c = tf.add(a, b)
with tf.Session() as sess:
    print(sess.run(b))

直接读取已经预加载在Graph中的数据,数据量大的时候,要把所有的数据都预加载,非常不合理

2.通过feed_dict

import numpy as np
import tensorflow as tf

x = np.reshape(np.arange(6), [2,3])
a = tf.zeros([2,3])
b = tf.placeholder(dtype=tf.float32, shape=[2,3])
c = tf.add(a, b)
with tf.Session() as sess:
    print(sess.run(b, feed_dict={
   b:x}))

也很简单,预先设置tf.placeholder即可

3.直接从文件中读取

主要是针对大数据,效率高

  • 通过tf.train.slice_input_producer,管理线程队列读取

原理讲解:https://zhuanlan.zhihu.com/p/27238630

import tensorflow as tf

x = [[1,2],[2,3],[4,5],[6,7],[7,8],[9,10],[11,12],[13,14]]
label = ["a","b","c","d","a","b","c","d"]
# 此处shuffle=True的话不需要tf.train.shuffle_batch,batch即可
input_queues = tf.train.slice_input_producer([x, label],shuffle=False,num_epochs=2) 
x, y = tf.train.batch(input_queues,
                          num_threads=8,
                          batch_size=3,
                          capacity= 128,
                          allow_smaller_final_batch=False)
with tf.Session() as sess:
    tf.local_variables_initializer().run()
    # 使用start_queue_runners之后,才会开始填充队列
    coord = tf.train.Coordinator()
    threads = tf
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值