Process
I was trying to use a tf.data.Dataset inside a TF1-style tf.Session running on TensorFlow 2.
First, I built the dataset directly like this:
dataset = tf.data.Dataset.from_tensor_slices((imgs, labels))
dataset = dataset.map(read_img_func).batch(batch_size)
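For reference, a minimal self-contained version of that pipeline. The post does not show the image/label arrays or `read_img_func`, so the in-memory uint8 arrays and the `read_img_func` body below (a cast to float32 and scaling into [0, 1]) are hypothetical stand-ins. In TF2's default eager mode this dataset can be iterated with a plain for loop:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-ins for the post's image and label arrays:
# five 4x4 RGB "images" stored as uint8, plus one integer label each.
imgs = np.arange(5 * 4 * 4 * 3, dtype=np.uint8).reshape(5, 4, 4, 3)
labels = np.arange(5, dtype=np.int64)
batch_size = 2

def read_img_func(img, label):
    # Hypothetical preprocessing (the real read_img_func is not shown):
    # convert to float32 and scale into [0, 1].
    return tf.cast(img, tf.float32) / 255.0, label

dataset = tf.data.Dataset.from_tensor_slices((imgs, labels))
dataset = dataset.map(read_img_func).batch(batch_size)

# Eager mode (the TF2 default): Python iteration is allowed.
batch_shapes = [tuple(img.shape) for img, _ in dataset]
print(batch_shapes)  # [(2, 4, 4, 3), (2, 4, 4, 3), (1, 4, 4, 3)]
```

The last batch has only one element because `batch()` keeps the remainder unless `drop_remainder=True` is passed.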
At run time I iterated over it like this:
for img, label in input_data:
    res = sess.run(self._out, feed_dict={self._input: img})
This fails with:
RuntimeError: `tf.data.Dataset` only supports Python-style iteration in eager mode or within tf.function.
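The error comes from graph mode: in a tf.Session workflow `tf.executing_eagerly()` is False, and a Dataset's Python `for` loop is only allowed in eager mode or inside a tf.function. A small sketch that reproduces it, assuming a toy dataset built inside an explicit `tf.Graph()` (which puts the code in graph mode without disabling eager execution globally):

```python
import numpy as np
import tensorflow as tf

def iteration_raises_in_graph_mode() -> bool:
    """Build a toy dataset inside an explicit graph and try a Python for loop."""
    graph = tf.Graph()
    with graph.as_default():
        # Inside this block tf.executing_eagerly() is False, just like in
        # TF1-style code that runs everything through a tf.compat.v1.Session.
        dataset = tf.data.Dataset.from_tensor_slices(
            np.zeros((4, 2), dtype=np.float32)).batch(2)
        try:
            for _ in dataset:
                pass
        except RuntimeError as err:
            print(err)
            return True
    return False

raised = iteration_raises_in_graph_mode()
```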
Next I switched to explicit iterators (one-shot or initializable):
iterators = input_data.make_one_shot_iterator()
img, label = iterators.get_next()
# or
iterators = input_data.make_initializable_iterator()
img, label = iterators.get_next()
Errors:
'BatchDataset' object has no attribute 'make_one_shot_iterator'
'BatchDataset' object has no attribute 'make_initializable_iterator'
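A quick check on a toy dataset (assuming TensorFlow 2.x) confirms that the TF2 `tf.data.Dataset` no longer carries these methods, while the module-level `tf.compat.v1.data` functions still build graph-mode iterators:

```python
import tensorflow as tf

# TF2's tf.data.Dataset no longer has the TF1 iterator methods...
dataset = tf.data.Dataset.range(3).batch(2)
old_methods = (hasattr(dataset, "make_one_shot_iterator"),
               hasattr(dataset, "make_initializable_iterator"))
print(old_methods)  # (False, False) on TensorFlow 2.x

# ...but the module-level functions still work in graph mode.
graph = tf.Graph()
with graph.as_default():
    toy = tf.data.Dataset.range(3).batch(2)  # toy dataset: [0, 1], then [2]
    next_batch = tf.compat.v1.data.make_one_shot_iterator(toy).get_next()

with tf.compat.v1.Session(graph=graph) as sess:
    first_batch = sess.run(next_batch)
print(first_batch.tolist())  # [0, 1]
```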
Looking at dataset_ops.py in the TensorFlow source on GitHub, these methods appear to have been replaced by module-level functions such as tf.compat.v1.data.make_one_shot_iterator(dataset), so the code became:
iterator = tf.compat.v1.data.make_initializable_iterator(input_data)
# iterator = tf.compat.v1.data.make_one_shot_iterator(input_data)
img, label = iterator.get_next()
res_list = []
while True:
    try:
        res = sess.run(self._out, feed_dict={self._input: img})
    except tf.errors.OutOfRangeError:
        print("iterator done")
        break
Error:
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles. For reference, the tensor object was Tensor("IteratorGetNext:0", shape=(None, 224, 224, 3), dtype=float32) which was passed to the argument `feed_dict` with key Tensor("Placeholder_137:0", shape=(None, 224, 224, 3), dtype=float32).
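A minimal reproduction, using a hypothetical placeholder `x` and a toy dataset: `get_next()` returns a symbolic `tf.Tensor`, and `Session.run` only accepts concrete values (NumPy arrays, Python scalars, lists) in `feed_dict`:

```python
import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Hypothetical model: a placeholder and one op that depends on it.
    x = tf.compat.v1.placeholder(tf.float32, shape=(None, 3))
    doubled = x * 2.0
    toy = tf.data.Dataset.from_tensor_slices(np.ones((4, 3), dtype=np.float32)).batch(2)
    next_x = tf.compat.v1.data.make_one_shot_iterator(toy).get_next()

with tf.compat.v1.Session(graph=graph) as sess:
    try:
        sess.run(doubled, feed_dict={x: next_x})  # next_x is a symbolic tf.Tensor
        feed_error = None
    except TypeError as err:
        feed_error = err
    # Evaluating the tensor first and feeding the resulting NumPy array works.
    fed_ok = sess.run(doubled, feed_dict={x: sess.run(next_x)})

print(type(feed_error).__name__, fed_ok.shape)  # TypeError (2, 3)
```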
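A feed value has to be a concrete value such as a NumPy array, not a graph tensor, so it looked like evaluating the tensor first with .eval(session=sess) would be enough:
res = sess.run(self._out, feed_dict={self._input: img.eval(session=sess)})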
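One caveat with `.eval()`: it runs the iterator's `IteratorGetNext` op on its own, so the label produced in the same step is thrown away, and a later evaluation of `label` comes from a different batch. A toy sketch (a one-shot iterator, so no initializer is needed; the data is hypothetical) that makes this visible:

```python
import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Toy data: "image" i has value i and label i, so a mismatch is easy to spot.
    toy_imgs = np.arange(4, dtype=np.float32).reshape(4, 1)
    toy_labels = np.arange(4, dtype=np.int64)
    ds = tf.data.Dataset.from_tensor_slices((toy_imgs, toy_labels)).batch(1)
    img, label = tf.compat.v1.data.make_one_shot_iterator(ds).get_next()

with tf.compat.v1.Session(graph=graph) as sess:
    img_val = img.eval(session=sess)   # runs IteratorGetNext: batch 0, its label is discarded
    label_val = sess.run(label)        # runs it again: batch 1, its image is discarded
    img_pair, label_pair = sess.run([img, label])  # one run: both from batch 2

print(img_val.tolist(), label_val.tolist())    # [[0.0]] [1]
print(img_pair.tolist(), label_pair.tolist())  # [[2.0]] [2]
```

Fetching `[img, label]` in a single `sess.run` call keeps images and labels paired.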
Then another error:
GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
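An initializable iterator has to be initialized before get_next() can produce anything, so adding sess.run(iterator.initializer) fixes it.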
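A sketch of the initializable-iterator lifecycle on a toy dataset: `get_next()` fails with `FailedPreconditionError` before `iterator.initializer` has run, and re-running the initializer restarts the dataset from the beginning, which is a common way to loop over several epochs:

```python
import numpy as np
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    ds = tf.data.Dataset.from_tensor_slices(np.arange(3, dtype=np.int64))  # toy data: 0, 1, 2
    iterator = tf.compat.v1.data.make_initializable_iterator(ds)
    next_value = iterator.get_next()

with tf.compat.v1.Session(graph=graph) as sess:
    # Before initialization, GetNext() raises FailedPreconditionError.
    try:
        sess.run(next_value)
        uninitialized_error = None
    except tf.errors.FailedPreconditionError as err:
        uninitialized_error = err

    epochs = []
    for _ in range(2):
        sess.run(iterator.initializer)  # (re)start from the first element
        values = []
        while True:
            try:
                values.append(int(sess.run(next_value)))
            except tf.errors.OutOfRangeError:  # end of this pass over the data
                break
        epochs.append(values)

print(epochs)  # [[0, 1, 2], [0, 1, 2]]
```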
Conclusion
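import numpy as np
import tensorflow as tf

with tf.device('/gpu:0'):
    iterator = tf.compat.v1.data.make_initializable_iterator(input_data)
    # iterator = tf.compat.v1.data.make_one_shot_iterator(input_data)  # needs no initializer
    sess.run(iterator.initializer)
    img, label = iterator.get_next()
    res_list = []
    while True:
        try:
            # img.eval() pulls the next batch as a NumPy array (the matching label batch
            # is discarded) and raises OutOfRangeError once the dataset is exhausted
            res, _input = sess.run([self._out, self._input],
                                   feed_dict={self._input: img.eval(session=sess)})
            res_list.append(res)
        except tf.errors.OutOfRangeError:
            print("iterator done")
            break
res = np.concatenate(res_list, axis=0)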
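For comparison, a self-contained sketch of the same loop that fetches each image and label batch together in one `sess.run` call and then feeds the NumPy image batch to the placeholder. `model_input` and `model_out` are hypothetical stand-ins for `self._input` and `self._out` (a placeholder and a per-image mean), the 4x4 images are toy data, and the device scope is omitted:

```python
import numpy as np
import tensorflow as tf

def run_inference(images, labels, batch_size):
    graph = tf.Graph()
    with graph.as_default():
        # Hypothetical stand-ins for self._input / self._out.
        model_input = tf.compat.v1.placeholder(tf.float32, shape=(None, 4, 4, 3))
        model_out = tf.reduce_mean(model_input, axis=[1, 2, 3])  # one score per image

        dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(batch_size)
        iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
        img, label = iterator.get_next()

    outputs, seen_labels = [], []
    with tf.compat.v1.Session(graph=graph) as sess:
        sess.run(iterator.initializer)
        while True:
            try:
                # One run pulls a matched (image, label) batch as NumPy arrays...
                img_batch, label_batch = sess.run([img, label])
            except tf.errors.OutOfRangeError:
                break
            # ...and a second run feeds the NumPy batch into the placeholder.
            outputs.append(sess.run(model_out, feed_dict={model_input: img_batch}))
            seen_labels.append(label_batch)
    return np.concatenate(outputs, axis=0), np.concatenate(seen_labels, axis=0)

images = np.stack([np.full((4, 4, 3), i, np.float32) for i in range(5)])
labels = np.arange(5, dtype=np.int64)
scores, out_labels = run_inference(images, labels, batch_size=2)
print(scores.tolist(), out_labels.tolist())  # [0.0, 1.0, 2.0, 3.0, 4.0] [0, 1, 2, 3, 4]
```

If the model can be built directly on the `img` tensor instead of a placeholder, the feed step (and the round trip through NumPy) can be dropped altogether, since running the output op then pulls batches from the iterator itself.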