tensorflow学习之static_rnn使用详解

tf.nn.static_rnn

Aliases:

  1. tf.contrib.rnn.static_rnn
  2. tf.nn.static_rnn

使用指定的RNN神经元创建循环神经网络

tf.nn.static_rnn(

    cell,

    inputs,

    initial_state=None,

    dtype=None,

    sequence_length=None,

    scope=None

)

参数说明:

  • cell:用于神经网络的RNN神经元,如BasicRNNCell,BasicLSTMCell
  • inputs:一个长度为T的list,list中的每个元素为一个Tensor,Tensor形如:[batch_size,input_size]
  • initial_state:RNN的初始状态,如果cell.state_size是一个整数,则它必须是适当类型和形如[batch_size,cell.state_size]的张量。如cell.state_size是一个元组,那么它应该是一个张量元组,对于cell.state_size中的s,应该是具有形如[batch_size,s]的张量的元组。
  • dtype:初始状态和预期输出的数据类型。可选参数。
  • sequence_length:指定每个输入的序列的长度。大小为batch_size的向量。
  • scope:变量范围

返回值:

一个(outputs,state)对

outputs:一个长度为T的list,list中的每个元素是每个输入对应的输出。例如一个时间步对应一个输出。

state:最终的状态

代码实例:

import tensorflow as tf



x=tf.Variable(tf.random_normal([2,4,3])) #[batch_size,timesteps,embedding_dim]

x=tf.unstack(x,axis=1) #按时间步展开

n_neurons = 5 #输出神经元数量



basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)

output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell,x, dtype=tf.float32)



print(len(output_seqs)) #四个时间步

print(output_seqs[0]) #每个时间步输出一个张量

print(output_seqs[1]) #每个时间步输出一个张量

print(states) #隐藏状态

输出如下:

4

Tensor("rnn/basic_rnn_cell/Tanh:0", shape=(2, 5), dtype=float32)

Tensor("rnn/basic_rnn_cell/Tanh_1:0", shape=(2, 5), dtype=float32)

Tensor("rnn/basic_rnn_cell/Tanh_3:0", shape=(2, 5), dtype=float32)

 

  • 15
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值