1. Data Input in TensorFlow
- Preloaded data: the data is embedded in the graph as constants
- Feeding: Python code generates the data and feeds it to the backend
- Reading from file: the data is read directly from files
Preload:
import tensorflow as tf
# define the graph
x1 = tf.constant([2, 3, 4])
x2 = tf.constant([4, 0, 1])
y = tf.add(x1, x2)
# define the session
with tf.Session() as sess:
    print(sess.run(y))
Feeding:
import tensorflow as tf
# define the graph
x1 = tf.placeholder(tf.int16)
x2 = tf.placeholder(tf.int16)
y = tf.add(x1, x2)
# Python generates the data
li1 = [2, 3, 4]
li2 = [4, 0, 1]
# session
with tf.Session() as sess:
    print(sess.run(y, feed_dict={x1: li1, x2: li2}))
Reading from file:
A typical file-reading pipeline contains the following stages:
- A list of filenames ["file0", "file1", ...]
- Optional shuffling of the filenames
- Optional limit on the number of epochs
- A filename queue
- A reader matched to the input file format
- A decoder for the records the reader produces
- Optional preprocessing
- An example queue
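Before turning to the TensorFlow version, the stages above are really just a producer/consumer pattern: one queue of filenames feeds a reader, which fills a second queue of examples. A minimal pure-Python sketch of that idea (no TensorFlow; `build_pipeline` and its fake one-record-per-file reader are illustrative inventions, not a real API):

```python
import queue
import random
import threading

def build_pipeline(filenames, num_epochs, shuffle=True):
    """Model of the filename queue -> reader -> example queue stages."""
    filename_queue = queue.Queue()
    example_queue = queue.Queue()

    # Stage 1: enqueue the (optionally shuffled) filenames once per epoch.
    for _ in range(num_epochs):
        names = list(filenames)
        if shuffle:
            random.shuffle(names)
        for name in names:
            filename_queue.put(name)

    # Stage 2: a "reader" thread pulls filenames and emits records.
    def reader():
        while not filename_queue.empty():
            name = filename_queue.get()
            # A real reader would open the file; we fake a single record.
            example_queue.put("record-from-" + name)

    t = threading.Thread(target=reader)
    t.start()
    t.join()
    return example_queue

q = build_pipeline(["file0", "file1"], num_epochs=2, shuffle=False)
records = [q.get() for _ in range(q.qsize())]
```

In real TensorFlow 1.x pipelines the reader threads are started by queue runners and run concurrently with training, rather than being joined up front as in this toy version.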
import os
import tensorflow as tf
filename = os.path.join(os.getcwd(), file_name)
# Build the filename queue; filename shuffling is configurable.
filename_queue = tf.train.string_input_producer([filename], shuffle=True)
reader = tf.TextLineReader(skip_header_lines=1)
# Each call to read() returns one line from a file in the queue.
key, value = reader.read(filename_queue)
record_defaults = [[0], [0], [0], [0]]
# Parse the line and convert it into a list of tensors, one per column.
decoded = tf.decode_csv(value, record_defaults=record_defaults)
with tf.Session() as sess:
    # The Coordinator lets all threads know when a stop signal arrives.
    coord = tf.train.Coordinator()
    # start_queue_runners must be called before run() or eval() executes
    # read(); otherwise the filename queue stays empty and read() blocks.
    threads = tf.train.start_queue_runners(coord=coord)
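The decoding step above turns each CSV line into one value per column, substituting the corresponding entry of record_defaults when a field is empty. A hedged pure-Python sketch of that per-line behavior (`decode_csv_line` is an illustrative helper, not part of TensorFlow):

```python
def decode_csv_line(line, record_defaults):
    """Mimic tf.decode_csv for one line: one value per column, falling
    back to the column default when a field is empty."""
    fields = line.split(",")
    if len(fields) != len(record_defaults):
        raise ValueError("wrong number of fields")
    out = []
    for field, default in zip(fields, record_defaults):
        if field == "":
            out.append(default[0])                  # use the column default
        else:
            out.append(type(default[0])(field))     # cast to default's type
    return out

row = decode_csv_line("1,,3,4", [[0], [0], [0], [0]])
```

With record_defaults of [[0], [0], [0], [0]], the empty second field is filled with 0 and every other field is parsed as an int.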
Batching
At the end of the input pipeline we need another queue to batch examples for training, evaluation, and inference, so we use the following statements to dequeue examples from the queue in shuffled order.
#min_after_dequeue defines how big a buffer we will randomly sample
# from -- bigger means better shuffling but slower start up and more
# memory used.
# capacity must be larger than min_after_dequeue and the amount larger
# determines the maximum we will prefetch. Recommendation:
# min_after_dequeue + (num_threads + a small safety margin) * batch_size
min_after_dequeue = 10000
capacity = min_after_dequeue + 3*batch_size
example_batch = tf.train.shuffle_batch(decoded, batch_size=batch_size,
                                       capacity=capacity,
                                       min_after_dequeue=min_after_dequeue)
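The comments above capture the key trade-off: shuffle_batch keeps a bounded buffer and only starts sampling once it holds at least min_after_dequeue elements, so a larger buffer shuffles better at the cost of memory and startup time. A rough pure-Python model of that buffer (`shuffle_stream` is an illustrative sketch, not the real implementation, which also handles batching and threading):

```python
import random

def shuffle_stream(stream, min_after_dequeue, capacity, seed=None):
    """Rough model of the shuffling buffer behind tf.train.shuffle_batch:
    buffer incoming elements and emit a randomly chosen one only once
    more than min_after_dequeue elements are buffered."""
    assert capacity > min_after_dequeue
    rng = random.Random(seed)
    buf = []
    for item in stream:
        buf.append(item)
        # Once enough elements are buffered, emit a random one; a bigger
        # buffer mixes elements drawn from further apart in the stream.
        if len(buf) > min_after_dequeue:
            yield buf.pop(rng.randrange(len(buf)))
    # Drain whatever is left when the input stream ends.
    while buf:
        yield buf.pop(rng.randrange(len(buf)))

shuffled = list(shuffle_stream(range(10), min_after_dequeue=3,
                               capacity=20, seed=0))
```

Every input element comes out exactly once, just in a randomized order whose quality depends on the buffer size.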