TensorArray
TensorArray
可以看做是具有动态size功能的Tensor数组。通常都是跟while_loop或map_fn结合使用。
常用方法有
write(index,value)
:将value写入TensorArray的第index个位置stack
:将TensorArray中的值作为Tensor返回
while_loop
final_state = tf.while_loop(cond, loop_body, init_state)
,作用是循环处理某个变量,中间处理的结果用来进行下一次处理,最后输出经过数次加工的变量,由于TensorArray可以动态扩展,因此常用来存储中间结果。
cond
:是一个函数,负责判断继续执行循环的条件。loop_body
:是每个循环体内执行的操作,负责对循环状态迸行更新。init_state
:为循环的起始状态,它可以包含多个Tensor 或者 TensorArray 。
如果用伪代码来表示运行逻辑的话,那 tf.while_loop 的功能与下面的代码相当 :
def while_loop(cond, loop_body, init_state):
state = init_state
while(cond(state)) : # 使用cond函数判断是否达到循环结束条件。
state = loop_body(state) # 使用loop_body函数对state进行更新。
return state
例子:
import tensorflow as tf
def condition(time, output_ta_l):
return tf.less(time, 3) # 真值比较 time小于3返回True 否则False
def body(time, output_ta_l):
output_ta_l = output_ta_l.write(time, [2.4, 3.5])
return time + 1, output_ta_l
time = tf.constant(0)
output_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True)
print(output_ta)
>>> <tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x0000016135AB4D88>
result = tf.while_loop(condition, body, loop_vars=[time, output_ta])
last_time, last_out = result
final_out = last_out.stack()
print(last_time.numpy())
>>> 3
print(last_out) # 还未解析的TensorArray
>>> <tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x000002189D415E08>
print(final_out.numpy()) # time从0到3,一共向last_out写了三次
>>> [[2.4 3.5]
[2.4 3.5]
[2.4 3.5]]
参考