这是一个使用tensorflow while循环的cumulative_max的实现,它需要进行n = len(x)次迭代.该代码是可复制粘贴运行的示例.
import tensorflow as tf
def tf_while_condition(x, loop_counter):
return tf.not_equal(loop_counter, 0)
def tf_while_body(x, loop_counter):
loop_counter -= 1
y = tf.concat(([x[0]], x[:-1]), axis=0)
z = tf.maximum(x, y)
return z, loop_counter
x = tf.constant([0,2,5,3,8,1,7])
cumulative_max, _ = tf.while_loop(cond=tf_while_condition,
body=tf_while_body,
loop_vars=(x, x.shape[0]))
with tf.Session() as sess:
print(sess.run(cumulative_max))
结果:
[0 2 5 5 8 8 8]
注意:如果要计算的向量很大,并且不需要反向传播,则可能值得在tf.while_loop中包含back_prop = False.
理解TF while循环的关键是要了解基于python的函数tf_while_condition和tf_while_body仅被调用一次以产生相关的tensorflow操作.这两个函数不在循环中调用.他们返回的操作将在sess.run计算期间在张量流图中的循环中执行.