# 本文主要讲述 TensorFlow 中迭代器的用法 — this script demonstrates how to use tf.data iterators in TensorFlow.
import tensorflow as tf
# Inspect the static element metadata of a dataset built from a 4x10
# tensor of uniform random floats drawn from [0, 8).
# from_tensor_slices slices along the first axis, so the dataset yields
# 4 elements, each a row of shape (10,).
_random_rows = tf.random_uniform([4, 10], 0, 8)
dateset1 = tf.data.Dataset.from_tensor_slices(_random_rows)
print(dateset1.output_types)   # float32 dtype
print(dateset1.output_shapes)  # (10,)
# One-shot iterator demo: iterate Dataset.range(100) exactly once.
# A one-shot iterator needs no explicit initializer run before
# get_next() can be evaluated.
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(100):
        value = sess.run(next_element)
        # Each sess.run advances the iterator by one element, so the
        # values must come out as 0, 1, 2, ... in order.
        assert i == value, "expected %d, got %d" % (i, value)
    # The loop never breaks, so the former `for ... else` clause always
    # ran; print unconditionally instead. This line is reached only if
    # every assertion above passed.
    print("guoyi")
# Initializable iterator demo. The dataset's size depends on a placeholder,
# so a parameterized (initializable) iterator is required. The iterator
# must be initialized with a fed value before get_next() can be evaluated.
max_value = tf.placeholder(tf.int64, shape=[])
dateset = tf.data.Dataset.range(max_value)
iterator = dateset.make_initializable_iterator()
next_element = iterator.get_next()

# The same iterator is used twice: each time in a fresh session,
# re-initialized with a different bound (first 10, then 100).
for limit in (10, 100):
    with tf.Session() as sess:
        sess.run(iterator.initializer, feed_dict={max_value: limit})
        for i in range(limit):
            value = sess.run(next_element)
            assert i == value