TensorArray可以看做是具有动态size功能的Tensor数组。通常都是跟while_loop或map_fn结合使用。
例子1:将[2.4, 3.5]写进TensorArray三次
import tensorflow as tf
def condition(time, output_ta_l):
return tf.less(time, 3)
def body(time, output_ta_l):
output_ta_l = output_ta_l.write(time, [2.4, 3.5])
return time + 1, output_ta_l
time = tf.constant(0)
output_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True)
result = tf.while_loop(condition, body, loop_vars=[time, output_ta])
last_time, last_out = result
final_out = last_out.stack()
with tf.Session():
print(last_time.eval())
print(final_out.eval())
Out:
3
[[ 2.4000001 3.5 ]
[ 2.4000001 3.5 ]
[ 2.4000001 3.5 ]]
重要函数:
ta.stack(name=None)
将TensorArray中元素叠起来当做一个Tensor输出
ta.unstack(value, name=None)
可以看做是stack的反操作,输入Tensor,输出一个新的TensorArray对象
ta.write(index, value, name=None)
指定index位置写入Tensor
ta.read(index, name=None)
读取指定index位置的Tensor
以上所有函数的参数name=None
均可用来指定当前操作的名称。