tf.TensorArray和tf.while_loop组合使用

TensorArray

TensorArray可以看做是具有动态size功能的Tensor数组。通常都是跟while_loop或map_fn结合使用。
常用方法有

  • write(index,value):将value写入TensorArray的第index个位置
  • stack:将TensorArray中的值作为Tensor返回

while_loop

final_state = tf.while_loop(cond, loop_body, init_state),作用是循环处理某个变量,中间处理的结果用来进行下一次处理,最后输出经过数次加工的变量,由于TensorArray可以动态扩展,因此常用来存储中间结果。

  • cond:是一个函数,负责判断继续执行循环的条件。
  • loop_body:是每个循环体内执行的操作,负责对循环状态迸行更新。
  • init_state:为循环的起始状态,它可以包含多个Tensor 或者 TensorArray 。

如果用伪代码来表示运行逻辑的话,那 tf.while_loop 的功能与下面的代码相当 :

def while_loop(cond, loop_body, init_state): 
    state = init_state 
    while(cond(state)) :   # 使用cond函数判断是否达到循环结束条件。
        state = loop_body(state)   # 使用loop_body函数对state进行更新。
    return state 

例子:

import tensorflow as tf

def condition(time, output_ta_l):
    return tf.less(time, 3)  # 真值比较 time小于3返回True 否则False

def body(time, output_ta_l):
    output_ta_l = output_ta_l.write(time, [2.4, 3.5])
    return time + 1, output_ta_l

time = tf.constant(0)
output_ta = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True)
print(output_ta)
>>> <tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x0000016135AB4D88>
result = tf.while_loop(condition, body, loop_vars=[time, output_ta])
last_time, last_out = result
final_out = last_out.stack()
print(last_time.numpy())
>>> 3
print(last_out) # 还未解析的TensorArray
>>> <tensorflow.python.ops.tensor_array_ops.TensorArray object at 0x000002189D415E08>
print(final_out.numpy()) # time从0到3,一共向last_out写了三次
>>> [[2.4 3.5]
	 [2.4 3.5]
	 [2.4 3.5]]

参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值