用while loop计算阶乘_tf.while_loop (TF r2.0) 使用指南

本文在 TensorFlow Core r2.0 下测试通过

本文关闭了 Eager mode

函数签名

tf.while_loop(
    cond,
    body,
    loop_vars,
    shape_invariants=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    maximum_iterations=None,
    name=None
)

功能

重复 body,直到 cond 返回 True

例子

例一:从1加到10

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

i = tf.constant(0)
acc = tf.constant(0)
cond = lambda i, _: tf.less(i, 10)
body = lambda i, acc: (tf.add(i, 1), tf.add(acc, 1))
graph = tf.while_loop(c, b, [i, acc])

with tf.compat.v1.Session() as sess:
    _, acc = sess.run(graph)
    print(acc)

------

10

例二:shape_invariants 的作用

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

i = tf.constant(0)
acc = tf.ones([2, 2])
cond = lambda i, _: i < 2
body = lambda i, acc: [i+1, tf.concat([acc, acc], axis=0)]
graph = tf.while_loop(
    cond, body, loop_vars=[i, acc],
    shape_invariants=[i.get_shape(), tf.TensorShape([None, 2])]
)

with tf.compat.v1.Session() as sess:
    _, acc = sess.run(g)
    print(acc)

------

初始 acc = 
[[1. 1.]
 [1. 1.]]

翻两倍(2 -> 4 -> 8)

结果 acc = 
[[1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]
 [1. 1.]]

如果 loop_vars 中变量的 shape 会发生改变, while_loop 要求使用 shape_invariants 显式指出,否则会报错。

例二中, acc 的 shape 每轮循环都会改变(翻倍)

细节

parallel_iterations

while_loop 并不是严格的循环,它允许循环的多次迭代并行执行

  • 可以使用 parallel_iterations 控制最大并发度
  • 如果程序正确,对于任意大于0的 parallel_iterations,程序运行的结果是稳定的

swap_memory

训练时,TensorFlow 会存储前向计传播产生的 Tensor 用于反向传播,这一行为有时会耗尽显存。开启 swap_memory 后,TensorFlow 会将这部分 Tensor 移动到 CPU。在训练 Large Batch 和 Long Sequence 时,这一特性比较有用。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值