tf.data迭代器问题

import tensorflow as tf 
import numpy as np

def __a(a):
    b=a+1
    b=np.squeeze(b)
    return b 

a=np.array(range(5))
b=tf.constant(a)
dataset =  tf.data.Dataset.from_tensor_slices(a)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=len(a), count=-1))
dataset= dataset.apply(tf.contrib.data.map_and_batch(
                map_func=lambda c: tf.py_func(__a, [c], [tf.int64]),
                batch_size=1))

print (a)
print ('wwwwwwwwwwwww')
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    for q in range(5):
        for i in range(5):
            value = sess.run(next_element)
            print (value)
            print ('xxx')
        print ('qqq')

在 for i in range(5):的时候,没问题,每一次大迭代都会遍历a中的元素,也就是0~4。

但是把这句话改为for i in range(2):的时候,就会变成如下图

 

也就是说前5次迭代还是会遍历0~4。 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值