【tensorflow】‘BatchDataset‘ object has no attribute ‘make_one_shot_iterator‘

过程

尝试在tf.sess中使用dataset。
首先直接用的以下方法制作dataset:

    dataset = tf.data.Dataset.from_tensor_slices((imgs, lables))
    dataset = dataset.map(read_img_func).batch(batch_size)

运行时用的:

for img, label in input_data:
	res = sess.run(self._out, feed_dict={self._input: img})

报错:

RuntimeError: `tf.data.Dataset` only supports Python-style iteration in eager mode or within tf.function.

随后换成以下代码:

iterators = input_data.make_one_shot_iterator()
img, label = iterators.get_next()

iterators = input_data.make_initializable_iterator()
img, label = iterators.get_next()

报错:

'BatchDataset' object has no attribute 'make_one_shot_iterator' 
'BatchDataset' object has no attribute 'make_initializable_iterator'

看了github上的源码dataset_op,似乎应该替换成tf.compat.v1.data.make_one_shot_iterator(dataset):

iterator = tf.compat.v1.data.make_initializable_iterator(input_data)
# iterators = input_data.make_one_shot_iterator()
img, label = iterator.get_next()
res_list = []
while True:
    try:
        import pdb
        pdb.set_trace()
        res = sess.run(self._out, feed_dict={self._input: img})
    except tf.errors.OutOfRangeError:
        print("iterator done")
        break

报错:

TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles. For reference, the tensor object was Tensor("IteratorGetNext:0", shape=(None, 224, 224, 3), dtype=float32) which was passed to the argument `feed_dict` with key Tensor("Placeholder_137:0", shape=(None, 224, 224, 3), dtype=float32).

看样子在feed_dict里面传入的值加一个.eval(session=sess)就可以了。

res = sess.run(self._out, feed_dict={self._input: img.eval(session=sess)})

然后又报错:

GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.

加个这个就行:sess.run(iterator.initializer)

结论

with tf.device(/gpu:0):
        iterator = tf.compat.v1.data.make_initializable_iterator(input_data)
        # iterators = input_data.make_one_shot_iterator()
        sess.run(iterator.initializer)
        img, label = iterator.get_next()
        res_list = []
        while True:
            try:
                res, _input = sess.run([self._out,self._input], feed_dict={self._input: img.eval(session=sess)})
                res_list.append(res)
            except tf.errors.OutOfRangeError:
                print("iterator done")
                break
        res = np.concatenate(res_list,axis=0)
  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

canmoumou

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值