一、基本介绍
和Tensor的主要区别是动态数组, 随时可以读取. 可以用在RNN每个时刻的states.
tf.TensorArray参数:
def __init__(self,
dtype,
size=None,
dynamic_size=None,
clear_after_read=None,
tensor_array_name=None,
handle=None,
flow=None,
infer_shape=True,
element_shape=None,
colocate_with_first_write_call=True,
name=None):
- dtype: 元素类型
- size: 大小
- dynamic_size: 如果为True, 则size满后仍然可write.
二、基本使用
import tensorflow as tf
a = tf.TensorArray(tf.float32, size=2, dynamic_size=True)
a = a.write(0, [0, 1]) # 这里的write需要赋值给对方.
a = a.write(1, [1, 0])
a = a.write(2, [1, 1])
read_value = a.read(0) # 读取某个索引下的值.
stack_value = a.stack()
concat_value = a.concat()
gather_value = a.gather([1, 2]) # gather是look up的意思.
with tf.Session() as sess:
print(read_value.eval())
print(stack_value.eval())
print(concat_value.eval())
print(gather_value.eval())
- read() 读取某一索引下的值,索引只能为单值,不能为list.
- stack() 将整个TensorArray转为Tensor输出.
- concat() 将整个TensorArray拉成(n, 1)
- gather() look up的意思.