TensorFlow-tf.nn.dynamic_rnn函数解析

眼看千遍,不如手动一遍,看了原文再手动整理一遍,代码实际操作一遍,加深理解。相当于高中时做的笔记了。

tf.nn.dynamic_rnn 函数是tensorflow封装的用来实现递归神经网络(RNN)的函数,本文会重点讨论一下tf.nn.dynamic_rnn 函数的参数及返回值。

首先来看一下该函数定义:

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

重要参数介绍:

  • cell:LSTM、GRU等的记忆单元。cell参数代表一个LSTM或GRU的记忆单元,也就是一个cell。例如,cell = tf.nn.rnn_cell.LSTMCell((num_units),其中,num_units表示rnn cell中神经元个数,也就是下文的cell.output_size。返回一个LSTM或GRU cell,作为参数传入。
  • inputs:输入的训练或测试数据,一般格式为[batch_size, max_time, embed_size],其中batch_size是输入的这批数据的数量,max_time就是这批数据中序列的最长长度,embed_size表示嵌入的词向量的维度。
  • sequence_length:是一个list,假设你输入了三句话,且三句话的长度分别是5,10,25,那么sequence_length=[5,10,25]。
  • time_major:决定了输出tensor的格式,如果为True, 张量的形状必须为 [max_time, batch_size,cell.output_size]。如果为False, tensor的形状必须为[batch_size, max_time, cell.output_size],cell.output_size表示rnn cell中神经元个数
  • 返回值:元组(outputs, states)

       outputs:outputs很容易理解,就是每个cell会有一个输出

       states:states表示最终的状态,也就是序列中最后一个cell输出的状态。一般情况下states的形状为 [batch_size, cell.output_size ],但当输入的cell为BasicLSTMCell时,state的形状为[2,batch_size, cell.output_size ],其中2也对应着LSTM中的cell state和hidden state。

那为什么states输出形状会有变化呢?states和outputs又有什么关系呢?

【1】对于第一问题“states”形状为什么会发生变化呢?

       我们以LSTMGRU分别为tf.nn.dynamic_rnn的输入cell类型为例,

       当cell为LSTM,states形状为[2,batch_size, cell.output_size ];

       当cell为GRU时,states形状为[batch_size, cell.output_size ]。

       其原因是因为LSTM和GRU的结构本身不同,如下面两个图所示,这是LSTM的cell结构,每个cell会有两个输出: C_{t}和 h_{t},上面这个图是输出C_{t},代表哪些信息应该被记住哪些应该被遗忘; 下面这个图是输出h_{t},代表这个cell的最终输出,LSTM的states是由 C_{t}和 h_{t}组成的,即states = (c,h)。

 

     当cell为GRU时,state就只有一个了,原因是GRU将  C_{t}和 h_{t}进行了简化,将其合并成了h_{t},如下图所示,GRU将遗忘门和输入门合并成了更新门,另外cell不再有细胞状态cell state,只有hidden state。

 

对于第二个问题outputs和states有什么关系?

       如果cell为LSTM,那 states是个tuple,分别代表 C_{t}和 h_{t},其中h_{t}与outputs中对应的最后一个时刻(即最后一个cell)的输出相等;如果cell为GRU,那么同理,states其实就是h_{t}

通过测试代码可以更好的理解。

import tensorflow as tf
import numpy as np
 
def dynamic_rnn(rnn_type='lstm'):
    # 创建输入数据,3代表batch size,6代表输入序列的最大步长(max time),4代表每个序列的维度
    X = np.random.randn(3, 6, 4)
 
    # 第二个输入的实际长度为4
    X[1, 4:] = 0
 
    #记录三个输入的实际步长
    X_lengths = [6, 4, 6]
 
    rnn_hidden_size = 5
    if rnn_type == 'lstm':
        cell = tf.contrib.rnn.BasicLSTMCell(num_units=rnn_hidden_size, state_is_tuple=True)
    else:
        cell = tf.contrib.rnn.GRUCell(num_units=rnn_hidden_size)
 
    outputs, last_states = tf.nn.dynamic_rnn(
        cell=cell,
        dtype=tf.float64,
        sequence_length=X_lengths,
        inputs=X)
 
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        o1, s1 = session.run([outputs, last_states])
        print(np.shape(o1))
        print(o1)
        print(np.shape(s1))
        print(s1)
 
if __name__ == '__main__':
    dynamic_rnn(rnn_type='lstm')

输出结果:

outputs,维度是[3,6,5],3是batch_size的大小,6是代表输入序列的最大步长(max time),5是rnn_hidden_size = 5

states,维度是[2,3,5],2是表示一个C_{t},一个 h_{t},3是batch_size的大小,5是rnn_hidden_size = 5

实验一解析:cell类型为LSTM,我们看看输出是什么样子,如上图所示,输入的形状为 [ 3, 6, 4 ],经过tf.nn.dynamic_rnn后outputs的形状为 [ 3, 6, 5 ]states形状为 [ 2, 3, 5 ],其中state第一部分为c,代表cell state;第二部分为h,代表hidden state。可以看到hidden state 与 对应的outputs的最后一行是相等的(rnn_hidden_size = 5)。另外需要注意的是输入一共有三个序列,但第二个序列的长度只有4,可以看到outputs中对应的两行值都为0,所以hidden state对应的是最后一个不为0的部分。tf.nn.dynamic_rnn通过设置sequence_length来实现这一逻辑,本例中的 X_lengths =[6,4,6]。

 

将上面代码输入改为dynamic_rnn(rnn_type='gru'),继续运行

输出结果:

outputs,维度是[3,6,5],3是batch_size的大小,6是代表输入序列的最大步长(max time),5是rnn_hidden_size = 5

states,维度是[3,5],因为gru不像lstm有C_{t}和 h_{t},3是batch_size的大小,5是rnn_hidden_size = 5

实验二解析:cell类型为GRU,我们看看输出是什么样子,如上图所示,输入的形状为 [ 3, 6, 4 ],经过tf.nn.dynamic_rnn后outputs的形状为 [ 3, 6, 5 ],state形状为 [ 3, 5 ]。可以看到 state 与 对应的outputs的最后一行是相等的。

 

 

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值