tensorflow使用计算图的模型,所以常规for循环在tensorflow其实是不起作用的。
所以tensorflow提供了while_loop函数:
tf.while_loop(
cond,
body,
loop_vars,
shape_invariants=None,
parallel_iterations=10,
back_prop=True,
swap_memory=False,
name=None,
maximum_iterations=None,
return_same_structure=False
)
具体参数就不一一介绍了,可以通过api文档或者help查询了解。
这里提一下当参数中存在list类型的变量时会产生的问题,
当参数存在list时,很容易会得到一些问题,如:
ValueError: Number of inputs and outputs of body must match loop_vars: 1, 2
这是通常是因为在body里对list进行了一些append一类的增或删操作,导致参数shape不匹配。
下面给出一种解决方案:
import tensorflow as tf
out = tf.Variable([])
i = tf.constant(0)
def cond(i, _):
return i < 10
def body(i, out):
i = i + 1
out = tf.concat([out, [1.0]], 0)
return [i, out]
_, out = tf.while_loop(cond, body, [i, out], shape_invariants=[i.get_shape(), tf.TensorShape([None])])
sess = tf.Session()
sess.run(tf.global_variables_initializer())
res = sess.run([_, out])
print(res) # [10, array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], dtype=float32)]
这里指定了shape_invariants参数,因为在while_loop中不希望参数的shape发生变化,因此在这里指定好shape,给一个tf.TensorShape([None])即自动推断长度,而不是固定检查,这样就可以解决list的长度在变化的问题了。