函数原型:
tf.while_loop(cond, body, loop_vars)
参数:
- cond: 函数类型, 作为循环结束判断条件, 返回值为 tf.bool 类型
- body: 函数类型, 作为循环体, 返回值长度 与 loop_vars 长度相同,比如 loop_vars 有3个参数,那么body的返回值也有3个;若 loop_vars 有5个参数,那么body的返回值就有5个;
- loop_vars: 传入cond 和 body 的参数, 注意哦,传入 cond 和 body 的都是这些个参数,loop_vars 全部参数先全部传入 cond 函数,如果判断为True, 就再把loop_vars 全部参数传入body 函数,返回值为更新后的loop_vars,长度为传入的loop_vars长度相同。
举例 1:
import tensorflow as tf
init_i = tf.constant(0)
def cond(i):
return tf.less(i, 10)
def body(i):
return tf.add(i, 1)
result_tf = tf.while_loop(cond, body, [init_i])
with tf.Session() as sess:
print(sess.run(result_tf))
# 输出为:
# 10
举例 2:
import tensorflow as tf
init_i = tf.constant(0)
init_A = tf.constant([[1, 1],
[1, 1]], tf.float32)
init_B = tf.constant([[2, 2],
[2, 2]], tf.float32)
def cond(i, A, B):
return tf.less(i, 2)
def body(i, A, B):
return tf.add(i, 1), A+A, A+B
i_tf, A_tf, B_tf = tf.while_loop(cond, body, [init_i, init_A, init_B])
with tf.Session() as sess:
i, A, B = sess.run([i_tf, A_tf, B_tf])
print('i:', i)
print('\nA:', A)
print('\nB:', B)
# 输出为:
# i: 2
#
# A: [[4. 4.]
# [4. 4.]]
#
# B: [[5. 5.]
# [5. 5.]]
- cond函数 和 body函数 的传入参数是一样的,都是 loop_vars
- body函数 的返回值和传入参数的长度是一样的,传入哪些就返回哪些,body函数只是更新了这些参数值
- loop_vars 的值一直在循环中更新,直到其不满足cond 函数