import tensorflow as tf
import numpy as np
def __a(a):
b=a+1
b=np.squeeze(b)
return b
a=np.array(range(5))
b=tf.constant(a)
dataset = tf.data.Dataset.from_tensor_slices(a)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=len(a), count=-1))
dataset= dataset.apply(tf.contrib.data.map_and_batch(
map_func=lambda c: tf.py_func(__a, [c], [tf.int64]),
batch_size=1))
print (a)
print ('wwwwwwwwwwwww')
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
for q in range(5):
for i in range(5):
value = sess.run(next_element)
print (value)
print ('xxx')
print ('qqq')
在 for i in range(5):的时候,没问题,每一次大迭代都会遍历a中的元素,也就是0~4。
但是把这句话改为for i in range(2):的时候,就会变成如下图
也就是说前5次迭代还是会遍历0~4。