记录一个花费很久的Error
自己的代码比较复杂点,所以直接引用链接的东西
https://github.com/tensorflow/tensorflow/issues/30435
直接运行以下代码(应该是tf1)
import tensorflow as tf
class A():
def __init__(self):
self.lst = [1, 2, 3]
self.sess = tf.Session()
self.total_length = tf.constant(len(self.lst))
def loop(self, i):
pr = tf.print(i)
current_value = self.lst[i.eval(session=self.sess)]##关键代码 标记为A
#current_value = self.lst[i] ##如果这样写不报错,且正常输出 标记为B
with tf.control_dependencies([pr]):
i = tf.add(i, 1)
return [i]
def cond(self, i):
return tf.less(i, self.total_length)
def run(self):
i = tf.constant(0)
while_op = tf.while_loop(self.cond, self.loop, [i])
final_i = self.sess.run(while_op)
if __name__ == "__main__":
obj = A()
obj.run()
报错
代码A的意思很明显,是因为i是个tf张量类型,想把i转成数字作为数组索引,输出数组的值。
代码B直接引用了tf张量类型的i作为索引,不知道为什么可以正常运行。
报错的原因,解释是说不能在while_loop里面使用session或eval把张量转成numpy数组。
tensorflow1不支持。