A usage example of tf.nn.dynamic_rnn, and understanding sequence_length for padded RNN input data

1. The dynamic_rnn function and its parameters

tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)

Builds an RNN (recurrent neural network) dynamically. TensorFlow can build an RNN either statically or dynamically. A static RNN takes longer to build: it is unrolled over the number of time steps (n) in the samples, so the graph contains n copies of the cell. A dynamic RNN creates the cell only once, and every time step of the sequence passes through that one cell in a loop. A static RNN therefore uses more memory and exports a larger model; the model keeps the intermediate state of every time step, which helps debugging, but at inference time the number of time steps must match the one used during training. A dynamic RNN uses less memory, exports a smaller model that keeps only the final state, and also supports a different number of time steps at inference time.
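A minimal sketch of the difference (TF 1.x APIs; the shapes here are made up for illustration): static_rnn needs the input pre-split into a fixed-length Python list and bakes that length into the graph, while dynamic_rnn takes the whole tensor and loops over time.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 3, 4])        # [batch, max_time, embed]

# static: the graph contains 3 unrolled copies of the cell; 3 is fixed
with tf.variable_scope("static"):
    static_in = tf.unstack(x, num=3, axis=1)        # list of 3 [batch, 4] tensors
    cell_s = tf.nn.rnn_cell.LSTMCell(num_units=5)
    static_out, static_state = tf.nn.static_rnn(cell_s, static_in, dtype=tf.float32)

# dynamic: one copy of the cell inside a while_loop; time steps can vary
with tf.variable_scope("dynamic"):
    cell_d = tf.nn.rnn_cell.LSTMCell(num_units=5)
    dyn_out, dyn_state = tf.nn.dynamic_rnn(cell_d, x, dtype=tf.float32)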

Parameters:

  • cell: an instance of RNNCell
  • inputs: a tensor. If time_major == False (the default), its shape must be [batch_size, max_time, embed_size]; if time_major == True, its shape must be [max_time, batch_size, embed_size]
  • sequence_length: (optional) an int32/int64 vector of size [batch_size]. When the current time-step index exceeds a sequence's actual length, that step is not computed: the RNN state is copied through from the previous step, and the output at that step is all zeros. (Used to copy-through state and zero-out outputs when past a batch element's sequence length.) This behavior is explained in detail in section 3 below
  • time_major: the shape format of the inputs and outputs tensors. If True, these tensors must be (and will be) [max_time, batch_size, depth]; if False, they must be (and will be) [batch_size, max_time, depth]
  • dtype: the expected data type of the outputs and the initial state
  • initial_state: (optional) the initial state of the RNN
  • scope: variable scope name

Example code:

import tensorflow as tf
import numpy as np

tf.reset_default_graph()                        # reset the default graph
x = np.random.randn(5, 3, 4)                    # random input of shape [5, 3, 4]
x[1, 2:] = 0                                    # zero-pad sequence 1 from time step 2 on
x_length = [3, 2, 3, 3, 1]                      # valid length of each sequence
print("matrix x:", x)
cell = tf.nn.rnn_cell.LSTMCell(num_units=5, state_is_tuple=True)        # LSTM cell with 5 hidden units

# create the dynamic RNN
outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=x_length,
    inputs=x)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# the RNN outputs
o = sess.run(outputs)
# the RNN final state
s = sess.run(last_states)
print("output.shape: ", o.shape)
print("output: \n", o)
print("last_o: \n", o[:, -1, :])  # output at the last time step; NOT equal to s.h for padded sequences

print("--" * 30)
# print("last_states.shape: ", s.shape)
print("last_states: \n", s)
print("last_state.c: \n", s.c)
print("last_states.h: \n", s.h)

Example output:

matrix x: [[[-0.2751823   0.62875351  0.08142847  0.0495135 ]
  [ 0.94518818 -1.08437101 -0.57259845 -1.15679631]
  [-0.19791739 -0.11366371  0.33196727 -2.95037402]]
 [[ 0.85933844 -1.26729369 -0.57044966 -0.24348295]
  [-0.6732532  -0.45629259  1.81781292 -0.85931593]
  [ 0.          0.          0.          0.        ]]
 [[ 0.55458312  0.4417129   0.21923621  0.19822528]
  [ 1.77390922 -1.11876087  0.02392109 -1.40040903]
  [ 1.11472234  0.05540735 -1.69082127 -0.70942592]]
 [[-1.13115574  1.31535108 -0.85216367 -1.49456911]
  [-0.01220003  1.95155256  0.08071599 -0.07094339]
  [-0.05802316 -0.09593088 -1.49409716 -1.20874127]]
 [[-0.65522406 -0.18979061  0.23646165  0.2793438 ]
  [ 1.0816631   0.09033521 -0.15652261 -0.33875969]
  [ 1.06508435 -1.79722593  0.33324731 -0.40462456]]]

output.shape:  (5, 3, 5)
output: 
 [[[-5.85462742e-02 -8.89051727e-02 -2.28658993e-02  2.66128034e-02
   -7.55145785e-05]
  [-6.84013892e-02  2.48924295e-01  7.79245155e-02 -1.11014269e-01
   -2.62299386e-02]
  [-3.24853092e-01  4.51116259e-01  4.64375879e-02 -1.77961139e-01
   -9.13509021e-02]]
 [[ 8.28080343e-02  2.18622594e-01  9.79521907e-02 -8.97420760e-02
    7.46208410e-03]
  [ 4.71283021e-02  1.60689486e-01  2.71381448e-02  9.04684446e-02
   -5.02056102e-02]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00]]
 [[-3.74208696e-02  5.97588579e-03 -6.88835980e-02 -8.74015371e-04
    9.03774497e-02]
  [-7.98153071e-02  3.60198083e-01 -1.43305154e-02 -1.48454462e-01
    3.01195219e-02]
  [-1.71999350e-01  5.55816955e-01 -2.21555223e-03 -2.68536099e-01
    9.84883712e-02]]
 [[-2.90373124e-01 -2.20932796e-01  1.78160625e-02 -1.13303889e-01
   -2.93868863e-01]
  [-2.94047579e-01 -3.25381890e-01 -2.96603985e-02 -1.38995839e-01
   -1.63629461e-01]
  [-3.87496091e-01 -1.88073840e-01  4.15822855e-02 -2.07960420e-01
   -2.19774137e-01]]
 [[ 6.56150662e-02 -4.79943937e-02  2.52785587e-02  7.85719525e-02
   -5.45248743e-02]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00]
  [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
    0.00000000e+00]]]
last_o: 
 [[-0.32485309  0.45111626  0.04643759 -0.17796114 -0.0913509 ]
 [ 0.          0.          0.          0.          0.        ]
 [-0.17199935  0.55581695 -0.00221555 -0.2685361   0.09848837]
 [-0.38749609 -0.18807384  0.04158229 -0.20796042 -0.21977414]
 [ 0.          0.          0.          0.          0.        ]]
------------------------------------------------------------
last_states: 
 LSTMStateTuple(c=array([[-0.6310722 ,  0.72561189,  0.16727535, -0.7120915 , -0.55199839],
       [ 0.1374107 ,  0.39603976,  0.09034975,  0.19272424, -0.20216476],
       [-0.25981531,  1.03403874, -0.00366755, -0.8580169 ,  0.21095968],
       [-0.67283588, -0.30168916,  0.06598855, -0.55158284, -0.469758  ],
       [ 0.14507145, -0.11368609,  0.05022219,  0.147962  , -0.10601435]]), h=array([[-0.32485309,  0.45111626,  0.04643759, -0.17796114, -0.0913509 ],
       [ 0.0471283 ,  0.16068949,  0.02713814,  0.09046844, -0.05020561],
       [-0.17199935,  0.55581695, -0.00221555, -0.2685361 ,  0.09848837],
       [-0.38749609, -0.18807384,  0.04158229, -0.20796042, -0.21977414],
       [ 0.06561507, -0.04799439,  0.02527856,  0.07857195, -0.05452487]]))
last_state.c: 
 [[-0.6310722   0.72561189  0.16727535 -0.7120915  -0.55199839]
 [ 0.1374107   0.39603976  0.09034975  0.19272424 -0.20216476]
 [-0.25981531  1.03403874 -0.00366755 -0.8580169   0.21095968]
 [-0.67283588 -0.30168916  0.06598855 -0.55158284 -0.469758  ]
 [ 0.14507145 -0.11368609  0.05022219  0.147962   -0.10601435]]
last_states.h: 
 [[-0.32485309  0.45111626  0.04643759 -0.17796114 -0.0913509 ]
 [ 0.0471283   0.16068949  0.02713814  0.09046844 -0.05020561]
 [-0.17199935  0.55581695 -0.00221555 -0.2685361   0.09848837]
 [-0.38749609 -0.18807384  0.04158229 -0.20796042 -0.21977414]
 [ 0.06561507 -0.04799439  0.02527856  0.07857195 -0.05452487]]
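Note the relationship between the two results: o[:, -1, :] is zero for the padded sequences (rows 2 and 5 of last_o above), while last_states.h holds the output at the last *valid* step of every sequence, because the state is copied through past each sequence's length. A quick numpy check, using o, s, and x_length from the example:

import numpy as np

# last valid output of each sequence, picked via its declared length
last_valid = np.array([o[i, l - 1, :] for i, l in enumerate(x_length)])
print(np.allclose(last_valid, s.h))   # True: h is the output of the last valid step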

2. Creating an RNN

This article builds the RNN from the existing library functions. There are two variants: a single-layer RNN model and a multi-layer RNN model.

2.1 Creating a single-layer RNN

import tensorflow as tf

# hypothetical sizes, added here so the snippet runs on its own
batch_size, max_time, embed_size, hidden_size = 4, 10, 8, 64
input_data = tf.placeholder(tf.float32, [batch_size, max_time, embed_size])

# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)

# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
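A quick smoke test of the snippet (a sketch; it relies on the hypothetical sizes defined above):

import numpy as np

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out_val, state_val = sess.run([outputs, state],
                                  {input_data: np.random.randn(4, 10, 8)})
    print(out_val.shape)    # (4, 10, 64) -> [batch_size, max_time, hidden_size]
    print(state_val.shape)  # (4, 64)     -> [batch_size, hidden_size]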

2.2 Creating a multi-layer RNN

import tensorflow as tf

# hypothetical input, added here so the snippet runs on its own
data = tf.placeholder(tf.float32, [None, 10, 8])   # [batch_size, max_time, embed_size]

# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]

# create an RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)

# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is an N-tuple where N is the number of LSTMCells, containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                   inputs=data,
                                   dtype=tf.float32)
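Since 'state' is a tuple with one LSTMStateTuple per layer, individual layers can be read out directly (a sketch on the graph above):

top_h = state[-1].h      # [batch_size, 256], hidden state of the top layer
bottom_c = state[0].c    # [batch_size, 128], cell state of the first layer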

 

3. Understanding the padding of RNN input data

        The rest of this article uses the dynamic RNN for the explanation. The library function already implements the body of the model, in other words the hidden layers are done; what remains are the input and output layers. The input layer must deliver data in the shape the model expects, [batch_size, time_step, embedding_size]. The data we actually have is usually not in this shape, so after preprocessing it must be reshaped accordingly, otherwise dynamic_rnn raises an error straight away. All sequences fed to the model in one batch must have the same length, so shorter sequences need padding. In the example code,

x[1, 2:] = 0

zero-pads part of the tensor. But if you only pad, and do not exclude the padded positions during training, the model will also learn "features" from those positions, which in turn distorts the gradient computation. Training therefore has to skip the computation over the padded part. This is where sequence_length comes in.

dynamic_rnn avoids computing over the padding by being told each sequence's valid length explicitly, via sequence_length. In the code above,

x_length = [3, 2, 3, 3, 1] 

means that the 1st sequence in the batch has 3 valid time steps, the 2nd has 2, the 3rd has 3, the 4th has 3, and the 5th has 1. A "sequence in the batch" is one slice along the batch_size dimension of the input [batch_size, time_step, embedding_size]; accordingly, x_length is a vector of shape [batch_size]. Note that the declared lengths take precedence over the data: element 4 of x was never zeroed, yet because x_length[4] = 1 its outputs are zero from step 1 onward and its state stops updating after step 0, as the printed output above shows. A helper for the common case is sketched below.
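In practice the lengths are rarely written by hand. A minimal sketch, assuming padded time steps are exactly all-zero rows (the helper name valid_lengths is made up for illustration):

import numpy as np

def valid_lengths(batch):
    """Count the time steps that contain any non-zero value.
    batch: [batch_size, time_step, embedding_size]."""
    mask = np.any(batch != 0, axis=2)          # True where the step holds data
    return mask.sum(axis=1).astype(np.int32)   # valid steps per sequence

print(valid_lengths(x))   # -> [3 2 3 3 3]: only element 1 of x was actually zero-padded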

4. Understanding the RNN output

The RNN output tensor has shape [batch_size, time_step, cell_state_size], where batch_size and time_step come from the first two dimensions of the input, and cell_state_size comes from the line in the example code

cell = tf.nn.rnn_cell.LSTMCell(num_units=5, state_is_tuple=True)  

namely num_units, the number of hidden units inside the cell (not the number of cells). Taking the example code:

output.shape: (5, 3, 5). The first two dimensions, 5 and 3, come from the input's batch_size and time_step; the last dimension, 5, comes from cell_state_size, i.e. num_units. A common follow-up pattern is sketched below.
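This sketch, written against the graph from section 1 (outputs, x_length), gathers the output at the last valid time step of each sequence inside the graph, e.g. to feed a classifier; for an LSTM this equals last_states.h:

seq_len = tf.constant(x_length, dtype=tf.int32)          # [batch_size]
batch_range = tf.range(tf.shape(outputs)[0])             # 0 .. batch_size-1
indices = tf.stack([batch_range, seq_len - 1], axis=1)   # [batch_size, 2]
last_relevant = tf.gather_nd(outputs, indices)           # [batch_size, num_units]
# sess.run(last_relevant) matches s.h from the example output above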

 

 
