加载数据
tensorflow作为符号编程框架,需要先构建数据流图,再读取数据,随后在进行模型的训练,所以其官网给出了三种加载数据的方式
- 预加载数据, 在tensorflow中,通过定义常量或者变脸来保存所有数据
- 填充数据 , 产生数据,再把数据填充后端
- 从文件中读取 从文件中直接读取,让队列管理器从文件中读取数据
(1)预加载数据
x1=tf.constant([2,3,4])
x2=tf.constant([4, 0, 1])
y=tf.add(x1,x2)
这种方法的缺点在于,将数据直接嵌入在数据流图中,当训练数据比较大的时候,很消耗内存
(2)填充数据
使用sess.run()中的feed_dict参数,将python产生的数据填充给后端
x1=tf.constant([2,3,4])
x2=tf.constant([4, 0, 1])
y=tf.add(x1,x2)
import tensorflow as tf
a1=tf.placeholder(tf.int16)
a2=tf.placeholder(tf.int16)
b=tf.add(x1,x2)
c1=[2,3,4]
c2=[4,5,6]
with tf.Session() as sess:
print(sess.run(b,feed_dict={a1:c1,a2:c2}))
填充的方式也有数据量大,消耗内存等缺点,并且数据类型转换等中间环节增减了不少开销,所以最好是使用第三种方法
(3)从文件中读取数据
从文件中读取数据分为两个步骤
- 把样本写入TFRecords二进制文件中
- 再从队列中读取