1.最简单的方式
import tensorflow as tf
a = tf.zeros([2,3])
b = tf.ones([2,3])
c = tf.add(a, b)
with tf.Session() as sess:
print(sess.run(b))
直接读取已经预加载在Graph中的数据,数据量大的时候,要把所有的数据都预加载,非常不合理
2.通过feed_dict
import numpy as np
import tensorflow as tf
x = np.reshape(np.arange(6), [2,3])
a = tf.zeros([2,3])
b = tf.placeholder(dtype=tf.float32, shape=[2,3])
c = tf.add(a, b)
with tf.Session() as sess:
print(sess.run(b, feed_dict={
b:x}))
也很简单,预先设置
tf.placeholder
即可
3.直接从文件中读取
主要是针对大数据,效率高
- 通过
tf.train.slice_input_producer
,管理线程队列读取
原理讲解:https://zhuanlan.zhihu.com/p/27238630
import tensorflow as tf
x = [[1,2],[2,3],[4,5],[6,7],[7,8],[9,10],[11,12],[13,14]]
label = ["a","b","c","d","a","b","c","d"]
# 此处shuffle=True的话不需要tf.train.shuffle_batch,batch即可
input_queues = tf.train.slice_input_producer([x, label],shuffle=False,num_epochs=2)
x, y = tf.train.batch(input_queues,
num_threads=8,
batch_size=3,
capacity= 128,
allow_smaller_final_batch=False)
with tf.Session() as sess:
tf.local_variables_initializer().run()
# 使用start_queue_runners之后,才会开始填充队列
coord = tf.train.Coordinator()
threads = tf