在使用tensorflow编程时,如何将输入数据传入计算图中是一个需要重点关注的问题,但是tensorflow中提供了库函数,将输入进行了封装,而我们只需要调用函数接口即可。主要的库函数在tensorflow.data.Dataset中。
1.输入数据确定
import tensorflow as tf
x=np.array([[1],[2],[3],[4]])
y=np.array([[1],[2],[3],[4]])
x_placeholder=tf.placeholder(dtype=tf.int32)
y_placeholder=tf.placeholder(dtype=tf.int32)
dataset=tf.data.Dataset.from_tensor_slices((x,y))
def func(x,y):
return x*1,y
dataset=dataset.map(func)
dataset=dataset.shuffle(2)
dataset=dataset.repeat()
dataset=dataset.batch(3)
iterator=dataset.make_initializable_iterator()
result_x,result_y=iterator.get_next()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
# sess.run(iterator.initializer,feed_dict={x_placeholder:x,y_placeholder:y})
sess.run(iterator.initializer)
for _ in range(1):
result =sess.run([result_x,result_y])
print(result)
result=sess.run([result_x,result_y])
print(result)
下面将解释每个函数的用法。
1) from_tensor_slices(tensors)
根据输入的tensors创建dataset
2)repeat(count=None)
表示数据集循环遍历的次数,None表示无限循环,count=epoch表示遍历数据集epoch次
3)shuffle(buffer_size,seed=None,reshuffle_each_iteration=None)
将数据集打乱,buffer_size意味着创建一个大小为buffer_size的缓冲区或者队列,每次从dataset中读取buffer_size个数据,并将缓冲区或者队列里的数据打乱再输出
4)batch(batch_size,drop_remainder=False)
从缓冲区中取出batch_size个数据
5)make_initializable_iterator(shared_name=None)
创建迭代器,通过迭代器来获取一个batch的数据
dataset = ...
iterator = dataset.make_initializable_iterator()
# ...
sess.run(iterator.initializer)
这里需要注意的是,在计算图上进行计算时,必须要执行迭代器的初始化!!!
6)get_next(name=None)
通过调用迭代器的get_next函数来得到一个batch的数据
7)map(map_func,num_parallel_calls=None)
map函数实现的功能是对输入数据进行处理,可以自行实现map_func,map_func函数的参数就是dataset的切片,输入参数个数和返回参数个数必须相等
2.输入数据不确定
如果输入数据不确定,则需要使用placehoder占位符。
import tensorflow as tf
x=np.array([[1],[2],[3],[4]])
y=np.array([[1],[2],[3],[4]])
x_placeholder=tf.placeholder(dtype=tf.int32)
y_placeholder=tf.placeholder(dtype=tf.int32)
dataset=tf.data.Dataset.from_tensor_slices((x,y))
def func(x,y):
return x*1,y
dataset=dataset.map(func)
dataset=dataset.shuffle(2)
dataset=dataset.repeat()
dataset=dataset.batch(3)
iterator=dataset.make_initializable_iterator()
result_x,result_y=iterator.get_next()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
sess.run(iterator.initializer,feed_dict={x_placeholder:x,y_placeholder:y})
# sess.run(iterator.initializer)
for _ in range(1):
result =sess.run([result_x,result_y])
print(result)
result=sess.run([result_x,result_y])
print(result)
在sess.run(iterator.initializer
时,就把输入数据传入就可以了