本文在 TensorFlow Core r2.0 下测试通过
本文关闭了 Eager mode
函数签名
tf.while_loop(
cond,
body,
loop_vars,
shape_invariants=None,
parallel_iterations=10,
back_prop=True,
swap_memory=False,
maximum_iterations=None,
name=None
)
功能
重复 body
,直到 cond
返回 True
例子
例一:从1加到10
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
i = tf.constant(0)
acc = tf.constant(0)
cond = lambda i, _: tf.less(i, 10)
body = lambda i, acc: (tf.add(i, 1), tf.add(acc, 1))
graph = tf.while_loop(c, b, [i, acc])
with tf.compat.v1.Session() as sess:
_, acc = sess.run(graph)
print(acc)
------
10
例二:shape_invariants 的作用
import tensorflow as tf
tf.compat.v1.disable_eager_execution()
i = tf.constant(0)
acc = tf.ones([2, 2])
cond = lambda i, _: i < 2
body = lambda i, acc: [i+1, tf.concat([acc, acc], axis=0)]
graph = tf.while_loop(
cond, body, loop_vars=[i, acc],
shape_invariants=[i.get_shape(), tf.TensorShape([None, 2])]
)
with tf.compat.v1.Session() as sess:
_, acc = sess.run(g)
print(acc)
------
初始 acc =
[[1. 1.]
[1. 1.]]
翻两倍(2 -> 4 -> 8)
结果 acc =
[[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]
[1. 1.]]
如果 loop_vars 中变量的 shape 会发生改变, while_loop 要求使用 shape_invariants 显式指出,否则会报错。
例二中, acc
的 shape 每轮循环都会改变(翻倍)
细节
parallel_iterations
while_loop
并不是严格的循环,它允许循环的多次迭代并行执行
- 可以使用
parallel_iterations
控制最大并发度 - 如果程序正确,对于任意大于0的
parallel_iterations
,程序运行的结果是稳定的
swap_memory
训练时,TensorFlow 会存储前向计传播产生的 Tensor 用于反向传播,这一行为有时会耗尽显存。开启 swap_memory
后,TensorFlow 会将这部分 Tensor 移动到 CPU。在训练 Large Batch 和 Long Sequence 时,这一特性比较有用。