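We all know that an RNN (recurrent neural network) can be viewed as a recursive computation. The state at each time step is computed from the state at the previous time step and the input at the current time step: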
$$H_{t+1} = f(H_t, x_{t+1})$$
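To make the recursion concrete, here is a plain-Python sketch (no TensorFlow) that unrolls it over a sequence. The function names `unroll` and `ema` are hypothetical. Using an exponential moving average as the step function f is an illustrative assumption, not something taken from the original code.

```python
def unroll(f, h0, xs):
    """Apply H_{t+1} = f(H_t, x_{t+1}) step by step, collecting every state."""
    h = h0
    states = []
    for x in xs:
        h = f(h, x)  # new state from the previous state and the current input
        states.append(h)
    return states, h  # states for every step, plus the final state


def ema(h, x, decay=0.5):
    """Hypothetical step function f: an exponential moving average."""
    return decay * h + (1 - decay) * x


states, final = unroll(ema, 0.0, [1.0, 1.0, 1.0])
print(states, final)  # [0.5, 0.75, 0.875] 0.875
```

Each step only sees the previous state and the current input. That is exactly the contract an `RNNCell` has to satisfy.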
So how does TensorFlow implement this kind of recursive computation?
RNNCell + dynamic_rnn
Referring to the figure above and the code below, the first step is to define a custom RNN class that inherits from `tf.nn.rnn_cell.RNNCell`.
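- The class must implement `state_size` and `output_size`. These are, respectively, the size of the hidden state passed from one time step to the next (usually called the state) and the size of the output at each time step.
- Next, it must implement the `__call__` method, whose arguments are `inputs` and `state`. `inputs` ([batch_size, dims]) is the input at the current time step. `state` ([batch_size, state_size]) is the state from the previous time step.
- `__call__` must return two values: the output at the current time step (`output` in the code) and the state at the current time step (`new_alphas` in the code).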
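The contract above can be illustrated without TensorFlow. The following sketch uses a hypothetical plain-Python class, `RunningSumCell`, which is not a TF class and does not inherit from `RNNCell`. It exposes the same pieces: `state_size`, `output_size`, and a `__call__(inputs, state)` that returns `(output, new_state)`. As a simplifying assumption, the input width equals the state size, so the cell can add the input to the state element by element.

```python
class RunningSumCell:
    """Hypothetical plain-Python stand-in that mirrors the RNNCell contract."""

    def __init__(self, num_units):
        self._num_units = num_units

    @property
    def state_size(self):
        return self._num_units  # size of the state handed to the next step

    @property
    def output_size(self):
        return self._num_units  # size of the output emitted at each step

    def __call__(self, inputs, state):
        # inputs: [batch_size][dims], state: [batch_size][state_size]
        new_state = [[s + x for s, x in zip(s_row, x_row)]
                     for s_row, x_row in zip(state, inputs)]
        output = new_state  # this toy cell emits its state as the output
        return output, new_state


cell = RunningSumCell(num_units=2)
output, new_state = cell([[1.0, 2.0]], [[10.0, 20.0]])
print(output, new_state)  # [[11.0, 22.0]] [[11.0, 22.0]]
```

As with `MyRnnCell` further down, the output and the new state are the same value here. Nothing in the contract requires that; a cell may return a different output than the state it passes on.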
Finally, to use this recursive RNN class, we instantiate it and then run it with `tf.nn.dynamic_rnn`:
```python
tf.nn.dynamic_rnn(
    cell=forward_cell,
    inputs=rest_of_input,
    sequence_length=sequence_lengths,
    initial_state=first_input,
    dtype=tf.float32)
```
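- `cell`: the recursive RNN cell object
- `inputs`: [batch_size, max_seq_len, dims]
- `sequence_length`: [batch_size], the actual length of each sequence. Past an example's length, `dynamic_rnn` zeroes the outputs and copies the state through unchanged.
- `initial_state`: [batch_size, state_size], the state at the initial time step, i.e. state(t=0) in the figure above
- `dtype`: the data type of the initial state and the outputs; required only when `initial_state` is not given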
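We can understand the process as follows:
- At t=1: from the initial state state(t=0) and the input X(t=1), compute state(t=1) and the output y(t=1);
- At t=2: from the previous state state(t=1) and the input X(t=2), compute state(t=2) and the output y(t=2);
- …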
Finally, `tf.nn.dynamic_rnn` returns:
- `output` ([batch_size, max_seq_len, output_size]): the outputs at all time steps, i.e. $(y(t=1), y(t=2), \ldots, y(t=n))$
- `state` ([batch_size, state_size]): the state at the last time step, i.e. $state(t=n)$. When `sequence_length` is given, this is the state at each example's last valid step.
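The step-by-step process and the two return values can be emulated in plain Python. The sketch below is not TensorFlow's implementation. It reproduces the documented behaviour on small nested lists, where outputs past an example's `sequence_length` are zero and that example's state is carried through unchanged. `dynamic_rnn_sketch` and the one-unit `SumCell` are hypothetical names chosen for illustration.

```python
def dynamic_rnn_sketch(cell, inputs, sequence_length, initial_state):
    """Plain-Python emulation of the loop tf.nn.dynamic_rnn runs (batch-major).

    inputs: [batch_size][max_seq_len][dims]; sequence_length: [batch_size];
    initial_state: [batch_size][state_size].
    """
    batch_size, max_seq_len = len(inputs), len(inputs[0])
    state = [list(row) for row in initial_state]
    outputs = [[None] * max_seq_len for _ in range(batch_size)]
    for t in range(max_seq_len):
        x_t = [inputs[b][t] for b in range(batch_size)]  # inputs at step t
        out_t, new_state = cell(x_t, state)
        for b in range(batch_size):
            if t < sequence_length[b]:
                outputs[b][t] = out_t[b]
                state[b] = new_state[b]
            else:
                # Past this example's length: zero output, state copied through.
                outputs[b][t] = [0.0] * cell.output_size
    return outputs, state


class SumCell:
    """Hypothetical one-unit cell: state <- state + input, output = new state."""
    output_size = 1

    def __call__(self, inputs, state):
        new_state = [[s[0] + x[0]] for s, x in zip(state, inputs)]
        return new_state, new_state


outputs, final_state = dynamic_rnn_sketch(
    SumCell(),
    inputs=[[[1.0], [2.0], [3.0]], [[1.0], [1.0], [1.0]]],
    sequence_length=[3, 1],
    initial_state=[[0.0], [100.0]])
print(outputs)      # [[[1.0], [3.0], [6.0]], [[101.0], [0.0], [0.0]]]
print(final_state)  # [[6.0], [101.0]]
```

The second example has length 1. Its outputs after t=1 are zeros, and its final state is the state from its last valid step (101.0), not from step n.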
```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.nn.rnn_cell, tf.placeholder, tf.Session)


class MyRnnCell(tf.nn.rnn_cell.RNNCell):
    """A custom RNN cell that runs one step of the CRF forward algorithm."""

    def __init__(self, transition_params):
        """Initialize the MyRnnCell.

        Args:
          transition_params: A [num_tags, num_tags] matrix of binary potentials.
              This matrix is expanded into a [1, num_tags, num_tags] in preparation
              for the broadcast summation occurring within the cell.
        """
        self._transition_params = tf.expand_dims(transition_params, 0)
        self._num_tags = transition_params.get_shape()[0].value

    @property
    def state_size(self):
        # Size of the state handed from one time step to the next.
        return self._num_tags

    @property
    def output_size(self):
        # Size of the output emitted at every time step.
        return self._num_tags

    def __call__(self, inputs, state, scope=None):
        """Run one step of the cell.

        Args:
          inputs: A [batch_size, num_tags] matrix of unary potentials.
          state: A [batch_size, num_tags] matrix containing the previous alpha
              values.
          scope: Unused variable scope of this cell.

        Returns:
          output, new_alphas: A pair of [batch_size, num_tags] matrices
              containing the new alpha values.
        """
        # [batch_size, num_tags] -> [batch_size, num_tags, 1]
        state = tf.expand_dims(state, 2)
        # This addition op broadcasts self._transition_params along the zeroth
        # dimension and state along the second dimension. This performs the
        # multiplication of previous alpha values and the current binary potentials
        # in log space. Result shape: [batch_size, num_tags, num_tags].
        transition_scores = state + self._transition_params
        # Log-sum-exp over the previous tag (axis 1) -> [batch_size, num_tags].
        new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1])
        output = new_alphas
        # Both the state and the output of this RNN cell contain the alphas values.
        # The output value is currently unused and simply satisfies the RNN API.
        # This could be useful in the future if we need to compute marginal
        # probabilities, which would require the accumulated alpha values at every
        # time step.
        return output, new_alphas


def rnn_test(inputs, transition_params, sequence_lengths):
    """Forward computation of alpha values."""
    # Split up the first and rest of the inputs in preparation for the forward
    # algorithm: the first time step becomes the initial state.
    first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
    first_input = tf.squeeze(first_input, [1])  # [batch_size, num_tags]
    # [batch_size, max_seq_len - 1, num_tags]
    rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
    # Compute the alpha values in the forward algorithm in order to get the
    # partition function.
    forward_cell = MyRnnCell(transition_params)
    output, state = tf.nn.dynamic_rnn(
        cell=forward_cell,
        inputs=rest_of_input,
        sequence_length=sequence_lengths,
        initial_state=first_input,
        dtype=tf.float32)
    return output, state


if __name__ == '__main__':
    inputs_arr = np.random.random([20, 10, 5])
    transition_params_arr = np.random.random([5, 5])
    # rest_of_input has 10 - 1 = 9 time steps, so every length must be in [0, 9]
    # (randint's upper bound of 10 is exclusive).
    sequence_lengths_arr = np.random.randint(0, 10, [20])

    inputs = tf.placeholder(tf.float32, [None, 10, 5])
    transition_params = tf.placeholder(tf.float32, [5, 5])
    sequence_lengths = tf.placeholder(tf.int64, [None])
    feed_dict = {inputs: inputs_arr, sequence_lengths: sequence_lengths_arr,
                 transition_params: transition_params_arr}

    output, state = rnn_test(inputs, transition_params, sequence_lengths)
    with tf.Session() as sess:
        print(sess.run([output, state], feed_dict=feed_dict))
```
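`MyRnnCell` performs one step of the CRF forward algorithm in log space: `new_alpha[j] = x[j] + logsumexp_i(alpha[i] + T[i][j])`. The following plain-Python sketch can be used to check that update without TensorFlow. The function names are hypothetical. It applies the same update to a single example and compares the resulting log partition function with a brute-force sum over every tag path. The final log-sum-exp over the last state is an extra step that the code above does not perform. It turns the final alphas into log Z.

```python
import itertools
import math
import random


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def forward_log_partition(unary, trans):
    """unary: [seq_len][num_tags]; trans: [num_tags][num_tags]. Same update as MyRnnCell."""
    num_tags = len(trans)
    alpha = list(unary[0])  # first time step = initial state
    for x in unary[1:]:     # remaining time steps = RNN inputs
        # new_alpha[j] = x[j] + logsumexp_i(alpha[i] + trans[i][j])
        alpha = [x[j] + logsumexp([alpha[i] + trans[i][j] for i in range(num_tags)])
                 for j in range(num_tags)]
    return logsumexp(alpha)  # log Z from the final state


def brute_force_log_partition(unary, trans):
    """Enumerate every tag path and log-sum-exp their scores."""
    num_tags = len(trans)
    scores = []
    for path in itertools.product(range(num_tags), repeat=len(unary)):
        score = sum(unary[t][tag] for t, tag in enumerate(path))
        score += sum(trans[a][b] for a, b in zip(path, path[1:]))
        scores.append(score)
    return logsumexp(scores)


random.seed(0)
unary = [[random.random() for _ in range(3)] for _ in range(4)]
trans = [[random.random() for _ in range(3)] for _ in range(3)]
print(forward_log_partition(unary, trans), brute_force_log_partition(unary, trans))
```

The two printed values agree up to floating-point rounding. The forward recursion needs O(seq_len · num_tags²) work, while enumeration grows as num_tags^seq_len.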