1. The dynamic_rnn function and its parameters
tf.nn.dynamic_rnn(
    cell,
    inputs,
    sequence_length=None,
    initial_state=None,
    dtype=None,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False,
    scope=None
)
Creates an RNN (recurrent neural network) dynamically. TensorFlow can build an RNN graph in two ways, statically or dynamically.
A statically built RNN takes longer to construct: it is unrolled over the number of time steps n in the samples, so n copies of the cell are created in the graph. It uses more memory and the exported model is larger, but the model keeps the intermediate state of every time step, which helps debugging. When a static RNN is used, the number of time steps at inference must match the number used at training.
A dynamic RNN creates the cell for only a single time step and loops the remaining steps through it. It uses less memory, the exported model is smaller and keeps only the final state, and it supports a different number of time steps at inference.
Parameters:
- cell: an RNNCell instance.
- inputs: a tensor. If time_major == False (the default), its shape must be [batch_size, max_time, embed_size]; if time_major == True, its shape must be [max_time, batch_size, embed_size].
- sequence_length: (optional) an int32/int64 vector of size [batch_size]. When the time-step index exceeds a sequence's actual length, that step is not computed: the RNN state is copied through from the previous step and the output of that step is all zeros. (Used to copy-through state and zero-out outputs when past a batch element's sequence length.) (Explained in more detail in section 3.)
- time_major: the shape format of the inputs and outputs tensors. If True, these tensors must be (and will be) [max_time, batch_size, depth]; if False, they must be (and will be) [batch_size, max_time, depth].
- dtype: the expected data type of the outputs and the initial state.
- initial_state: (optional) the initial state of the RNN.
- scope: variable scope (namespace) for the created subgraph.
Example code:
import tensorflow as tf
import numpy as np

tf.reset_default_graph()               # reset the default graph
x = np.random.randn(5, 3, 4)           # random input of shape [5, 3, 4]
x[1, 2:] = 0                           # zero-pad the trailing step of sample 1
x_length = [3, 2, 3, 3, 1]             # valid length of each sequence
print("matrix x:", x)
cell = tf.nn.rnn_cell.LSTMCell(num_units=5, state_is_tuple=True)  # LSTM cell with 5 hidden units (not 5 cells)
# build the dynamic RNN
outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=x_length,
    inputs=x)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# RNN outputs
o = sess.run(outputs)
# RNN final state
s = sess.run(last_states)
print("output.shape: ", o.shape)
print("output: \n", o)
print("last_o: \n", o[:, -1, :])  # NOT always s.h: for padded sequences the last step is all zeros
print("--" * 30)
# print("last_states.shape: ", s.shape)  # s is an LSTMStateTuple, not a single tensor
print("last_states: \n", s)
print("last_state.c: \n", s.c)
print("last_states.h: \n", s.h)
Example output:
matrix x: [[[-0.2751823 0.62875351 0.08142847 0.0495135 ]
[ 0.94518818 -1.08437101 -0.57259845 -1.15679631]
[-0.19791739 -0.11366371 0.33196727 -2.95037402]]
[[ 0.85933844 -1.26729369 -0.57044966 -0.24348295]
[-0.6732532 -0.45629259 1.81781292 -0.85931593]
[ 0. 0. 0. 0. ]]
[[ 0.55458312 0.4417129 0.21923621 0.19822528]
[ 1.77390922 -1.11876087 0.02392109 -1.40040903]
[ 1.11472234 0.05540735 -1.69082127 -0.70942592]]
[[-1.13115574 1.31535108 -0.85216367 -1.49456911]
[-0.01220003 1.95155256 0.08071599 -0.07094339]
[-0.05802316 -0.09593088 -1.49409716 -1.20874127]]
[[-0.65522406 -0.18979061 0.23646165 0.2793438 ]
[ 1.0816631 0.09033521 -0.15652261 -0.33875969]
[ 1.06508435 -1.79722593 0.33324731 -0.40462456]]]
output.shape: (5, 3, 5)
output:
[[[-5.85462742e-02 -8.89051727e-02 -2.28658993e-02 2.66128034e-02
-7.55145785e-05]
[-6.84013892e-02 2.48924295e-01 7.79245155e-02 -1.11014269e-01
-2.62299386e-02]
[-3.24853092e-01 4.51116259e-01 4.64375879e-02 -1.77961139e-01
-9.13509021e-02]]
[[ 8.28080343e-02 2.18622594e-01 9.79521907e-02 -8.97420760e-02
7.46208410e-03]
[ 4.71283021e-02 1.60689486e-01 2.71381448e-02 9.04684446e-02
-5.02056102e-02]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00]]
[[-3.74208696e-02 5.97588579e-03 -6.88835980e-02 -8.74015371e-04
9.03774497e-02]
[-7.98153071e-02 3.60198083e-01 -1.43305154e-02 -1.48454462e-01
3.01195219e-02]
[-1.71999350e-01 5.55816955e-01 -2.21555223e-03 -2.68536099e-01
9.84883712e-02]]
[[-2.90373124e-01 -2.20932796e-01 1.78160625e-02 -1.13303889e-01
-2.93868863e-01]
[-2.94047579e-01 -3.25381890e-01 -2.96603985e-02 -1.38995839e-01
-1.63629461e-01]
[-3.87496091e-01 -1.88073840e-01 4.15822855e-02 -2.07960420e-01
-2.19774137e-01]]
[[ 6.56150662e-02 -4.79943937e-02 2.52785587e-02 7.85719525e-02
-5.45248743e-02]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00]
[ 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00]]]
last_o:
[[-0.32485309 0.45111626 0.04643759 -0.17796114 -0.0913509 ]
[ 0. 0. 0. 0. 0. ]
[-0.17199935 0.55581695 -0.00221555 -0.2685361 0.09848837]
[-0.38749609 -0.18807384 0.04158229 -0.20796042 -0.21977414]
[ 0. 0. 0. 0. 0. ]]
------------------------------------------------------------
last_states:
LSTMStateTuple(c=array([[-0.6310722 , 0.72561189, 0.16727535, -0.7120915 , -0.55199839],
[ 0.1374107 , 0.39603976, 0.09034975, 0.19272424, -0.20216476],
[-0.25981531, 1.03403874, -0.00366755, -0.8580169 , 0.21095968],
[-0.67283588, -0.30168916, 0.06598855, -0.55158284, -0.469758 ],
[ 0.14507145, -0.11368609, 0.05022219, 0.147962 , -0.10601435]]), h=array([[-0.32485309, 0.45111626, 0.04643759, -0.17796114, -0.0913509 ],
[ 0.0471283 , 0.16068949, 0.02713814, 0.09046844, -0.05020561],
[-0.17199935, 0.55581695, -0.00221555, -0.2685361 , 0.09848837],
[-0.38749609, -0.18807384, 0.04158229, -0.20796042, -0.21977414],
[ 0.06561507, -0.04799439, 0.02527856, 0.07857195, -0.05452487]]))
last_state.c:
[[-0.6310722 0.72561189 0.16727535 -0.7120915 -0.55199839]
[ 0.1374107 0.39603976 0.09034975 0.19272424 -0.20216476]
[-0.25981531 1.03403874 -0.00366755 -0.8580169 0.21095968]
[-0.67283588 -0.30168916 0.06598855 -0.55158284 -0.469758 ]
[ 0.14507145 -0.11368609 0.05022219 0.147962 -0.10601435]]
last_states.h:
[[-0.32485309 0.45111626 0.04643759 -0.17796114 -0.0913509 ]
[ 0.0471283 0.16068949 0.02713814 0.09046844 -0.05020561]
[-0.17199935 0.55581695 -0.00221555 -0.2685361 0.09848837]
[-0.38749609 -0.18807384 0.04158229 -0.20796042 -0.21977414]
[ 0.06561507 -0.04799439 0.02527856 0.07857195 -0.05452487]]
2. Creating an RNN
This article builds the RNN from the library's ready-made cell functions. An RNN can be built either as a single-layer model or as a multi-layer (stacked) model.
2.1 Creating a single-layer RNN model
# create a BasicRNNCell
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
# defining initial state
initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)
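What a BasicRNNCell computes at each step is output = new_state = activation(input · W + state · U + b), with tanh as the default activation. A minimal NumPy sketch of one step (the names W, U, b are illustrative; TensorFlow actually stores them as a single concatenated kernel plus a bias):

```python
import numpy as np

# One step of a basic RNN cell: output = new_state = tanh(x·W + h·U + b).
def basic_rnn_step(x_t, h_prev, W, U, b):
    h_new = np.tanh(x_t @ W + h_prev @ U + b)
    return h_new, h_new   # for a basic RNN, output and state are identical

embed_size, hidden_size, batch_size = 4, 6, 2
rng = np.random.RandomState(1)
W = rng.randn(embed_size, hidden_size)
U = rng.randn(hidden_size, hidden_size)
b = np.zeros(hidden_size)
h0 = np.zeros((batch_size, hidden_size))   # analogous to zero_state
out, h1 = basic_rnn_step(rng.randn(batch_size, embed_size), h0, W, U, b)
print(out.shape)   # (2, 6): [batch_size, cell_state_size]
```

This also shows why the output's last dimension is cell_state_size: for a basic cell the output simply is the hidden state.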
2.2 Creating a multi-layer RNN model
# create 2 LSTMCells
rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
# create an RNN cell composed sequentially of a number of RNNCells
multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
# 'outputs' is a tensor of shape [batch_size, max_time, 256]
# 'state' is an N-tuple where N is the number of LSTMCells, containing a
# tf.contrib.rnn.LSTMStateTuple for each cell
outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
                                   inputs=data,
                                   dtype=tf.float32)
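MultiRNNCell simply chains cells: at each time step the first cell consumes the input, and each subsequent cell consumes the previous cell's output, which is why the final outputs tensor has the last cell's size (256 above). A NumPy sketch of the wiring, using illustrative plain tanh cells and small sizes in place of [128, 256]:

```python
import numpy as np

# an illustrative plain tanh cell (stand-in for an RNNCell)
def tanh_cell(x_t, h_prev, W, U, b):
    h = np.tanh(x_t @ W + h_prev @ U + b)
    return h, h

# two stacked layers with sizes 8 and 16 (stand-ins for [128, 256])
sizes, embed, batch, max_time = [8, 16], 4, 3, 5
rng = np.random.RandomState(2)
params, in_dim = [], embed
for size in sizes:
    params.append((rng.randn(in_dim, size), rng.randn(size, size), np.zeros(size)))
    in_dim = size               # the next layer consumes this layer's output

x = rng.randn(batch, max_time, embed)
states = [np.zeros((batch, size)) for size in sizes]
outputs = np.zeros((batch, max_time, sizes[-1]))
for t in range(max_time):
    layer_in = x[:, t, :]
    for i, (W, U, b) in enumerate(params):
        layer_in, states[i] = tanh_cell(layer_in, states[i], W, U, b)
    outputs[:, t, :] = layer_in  # the top layer's output is the stack's output
print(outputs.shape)   # (3, 5, 16): last dim is the top cell's size
```

The state list keeps one entry per layer, mirroring the N-tuple that dynamic_rnn returns for a MultiRNNCell.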
3. Understanding input padding
This section uses the dynamically created RNN for illustration. The library function already implements the body of the model; in other words, the hidden layers are done, and what remains are the input and output layers. For the input layer, the data fed to the model must follow the expected format: a tensor of shape [batch_size, time_step, embedding_size]. The data we collect usually does not come in this shape, so after preprocessing it must be reshaped accordingly, or an error is raised immediately. All sequences in one batch must have the same length, so shorter sequences need padding. In the example code,
x[1, 2:] = 0
pads the tensor with zeros. However, if we only pad and do not exclude the padded positions during training, the model will also learn features at those positions and they will distort the gradient computation. Training therefore has to skip the padded steps, and this is where sequence_length comes in.
dynamic_rnn avoids computing the padded steps by being told each sequence's valid length explicitly, via sequence_length. In the code above,
x_length = [3, 2, 3, 3, 1]
means the 1st sample in the batch has 3 valid time steps, the 2nd has 2, the 3rd has 3, the 4th has 3, and the 5th has 1. A sample here is one element along the batch_size dimension of the input tensor [batch_size, time_step, embedding_size], so x_length is a vector of size [batch_size].
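The copy-through and zero-out behavior described above can be sketched with a tiny hand-rolled loop. This is only a NumPy illustration using a plain tanh cell, not dynamic_rnn's actual implementation:

```python
import numpy as np

def masked_rnn(x, seq_len, W, U, b):
    """Minimal tanh RNN mimicking dynamic_rnn's sequence_length handling:
    past a sequence's valid length, the output is zeroed and the state
    is copied through unchanged."""
    batch, max_time, _ = x.shape
    hidden = U.shape[0]
    h = np.zeros((batch, hidden))                 # initial state
    outputs = np.zeros((batch, max_time, hidden))
    for t in range(max_time):
        h_new = np.tanh(x[:, t, :] @ W + h @ U + b)
        valid = (np.asarray(seq_len) > t)[:, None]      # which sequences are still active
        h = np.where(valid, h_new, h)                   # copy-through state past the length
        outputs[:, t, :] = np.where(valid, h_new, 0.0)  # zero-out padded outputs
    return outputs, h

rng = np.random.RandomState(0)
x = rng.randn(5, 3, 4)
W, U, b = rng.randn(4, 5), rng.randn(5, 5), np.zeros(5)
o, last_h = masked_rnn(x, [3, 2, 3, 3, 1], W, U, b)
print(o[1, 2])                          # third step of sample 1: all zeros (length 2)
print(np.allclose(last_h[1], o[1, 1]))  # final state equals last valid output -> True
```

For this simple cell, the returned final state is each sequence's last valid output, exactly the role s.h plays in the LSTM example above.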
4. Understanding the RNN output
The RNN output tensor has shape [batch_size, time_step, cell_state_size], where batch_size and time_step come from the first two dimensions of the input tensor, and cell_state_size comes from the example's
cell = tf.nn.rnn_cell.LSTMCell(num_units=5, state_is_tuple=True)
num_units argument, i.e. the number of hidden units in the cell (not the number of cells). Taking the example code:
output.shape: (5, 3, 5): the first two dimensions, 5 and 3, come from the input's batch_size and time_step; the last dimension, 5, comes from cell_state_size (num_units).
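Note that the o[:, -1, :] printed in section 1 is not the last valid output for the shorter sequences, because padded steps are zeroed; sample i's last valid output sits at time index x_length[i] - 1, which is what s.h holds for the LSTM. A NumPy sketch of gathering it, on synthetic data shaped like the example's o:

```python
import numpy as np

x_length = [3, 2, 3, 3, 1]
rng = np.random.RandomState(3)
o = rng.randn(5, 3, 5)
# mimic dynamic_rnn: zero the outputs past each sequence's valid length
for i, n in enumerate(x_length):
    o[i, n:, :] = 0.0

# gather each sequence's last *valid* output via advanced indexing
idx = np.asarray(x_length) - 1
last_valid = o[np.arange(len(x_length)), idx, :]
print(last_valid.shape)          # (5, 5)
print(np.allclose(o[1, -1], 0))  # True: sample 1's final step was padded away
```

This is a common post-processing step when the per-sequence final output is fed into a downstream classifier.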