tf.TensorArray (tensor动态数组)

一、基本介绍

和Tensor的主要区别是动态数组, 随时可以读取. 可以用在RNN每个时刻的states.
tf.TensorArray参数:

def __init__(self,
               dtype,
               size=None,
               dynamic_size=None,
               clear_after_read=None,
               tensor_array_name=None,
               handle=None,
               flow=None,
               infer_shape=True,
               element_shape=None,
               colocate_with_first_write_call=True,
               name=None):
  • dtype: 元素类型
  • size: 大小
  • dynamic_size: 如果为True, 则size满后仍然可write.

二、基本使用

import tensorflow as tf

a = tf.TensorArray(tf.float32, size=2, dynamic_size=True)
a = a.write(0, [0, 1])          # 这里的write需要赋值给对方.
a = a.write(1, [1, 0])
a = a.write(2, [1, 1])

read_value = a.read(0)      # 读取某个索引下的值.
stack_value = a.stack()
concat_value = a.concat()
gather_value = a.gather([1, 2])   # gather是look up的意思.


with tf.Session() as sess:
    print(read_value.eval())
    print(stack_value.eval())
    print(concat_value.eval())
    print(gather_value.eval())
  • read() 读取某一索引下的值,索引只能为单值,不能为list.
  • stack() 将整个TensorArray转为Tensor输出.
  • concat() 将整个TensorArray拉成(n, 1)
  • gather() look up的意思.
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值