How TensorFlow Implements the Recursive Computation of an RNN

We all know that an RNN is a recurrent neural network, which can be viewed as a recursive computation: the hidden state at each time step is computed from the previous time step's hidden state and the current time step's input:

H_{t+1} = f(H_t, x_{t+1})
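Before looking at the TensorFlow machinery, it helps to see that this recurrence is nothing more than a loop: each iteration consumes the previous state and one input. Here is a minimal NumPy sketch (the tanh cell and the weights W_h, W_x are illustrative stand-ins for f, not anything from the code below):

import numpy as np

def f(h, x, W_h, W_x):
    # One step of the recurrence: H_{t+1} = f(H_t, x_{t+1}).
    return np.tanh(h @ W_h + x @ W_x)

state_size, dims, seq_len = 4, 3, 5
W_h = np.random.random([state_size, state_size])
W_x = np.random.random([dims, state_size])

h = np.zeros([state_size])                   # H_0, the initial state
for x in np.random.random([seq_len, dims]):  # x_1, ..., x_n
    h = f(h, x, W_h, W_x)                    # H_t depends only on H_{t-1} and x_t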

So how does TensorFlow implement this kind of recursive computation?

RNNCell + dynamic_rnn

(Figure: an RNN cell unrolled over time. The initial state state(t=0) and input X(t=1) produce output y(t=1) and state(t=1), which feeds the next step, and so on.)

Referring to the figure above together with the code below: first, we define a custom RNN class that inherits from tf.nn.rnn_cell.RNNCell

  1. The class must implement the state_size and output_size properties; these are, respectively, the size of the hidden state passed from one time step to the next (conventionally called the state) and the size of the output produced at each time step;

  2. It must also implement the __call__ method, whose arguments must be inputs and state:

    Here, inputs ([batch_size, dims]) is the input at the current time step;

    state ([batch_size, state_size]) is the state from the previous time step.

  3. __call__ must return two things:

    the output at the current time step (output in the code below), and the state at the current time step (new_alphas in the code below).

  4. Finally, to use this recursive RNN cell, we first instantiate it and then drive it with tf.nn.dynamic_rnn:

tf.nn.dynamic_rnn(
        cell=forward_cell,
        inputs=rest_of_input,
        sequence_length=sequence_lengths,
        initial_state=first_input,
        dtype=tf.float32)

cell: the RNN cell object that performs the recursive computation
inputs: [batch_size, max_seq_len, dims]
sequence_length: [batch_size]
initial_state: [batch_size, state_size], the state at the initial time step, i.e. state(t=0) in the figure above

We can picture the process like this:

First, at time t=1: from the initial state state(t=0) and the input X(t=1), compute state(t=1) and the output y(t=1);

Then, at time t=2: from the previous state state(t=1) and the input X(t=2), compute state(t=2) and the output y(t=2); and so on.

Finally, tf.nn.dynamic_rnn returns:

output ([batch_size, max_seq_len, output_size]): the outputs at all time steps, i.e. (y(t=1), y(t=2), ..., y(t=n));

state ([batch_size, state_size]): the state at the last time step, i.e. state(t=n).
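Conceptually, then, tf.nn.dynamic_rnn is just a loop over the time axis. Here is a simplified NumPy sketch (my_dynamic_rnn is an illustrative name; cell is any object implementing the __call__ contract above; the per-example sequence_length masking that the real op performs is left out):

import numpy as np

def my_dynamic_rnn(cell, inputs, initial_state):
    # inputs: [batch_size, max_seq_len, dims]; initial_state: [batch_size, state_size].
    state = initial_state
    outputs = []
    for t in range(inputs.shape[1]):
        # Each step consumes X(t) and state(t-1), producing y(t) and state(t).
        output, state = cell(inputs[:, t, :], state)
        outputs.append(output)
    # outputs stack to [batch_size, max_seq_len, output_size]; state is state(t=n).
    return np.stack(outputs, axis=1), state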

import numpy as np
import tensorflow as tf


class MyRnnCell(tf.nn.rnn_cell.RNNCell):

    def __init__(self, transition_params):
        """Initialize the MyRnnCell.

        Args:
          transition_params: A [num_tags, num_tags] matrix of binary potentials.
              This matrix is expanded into a [1, num_tags, num_tags] in preparation
              for the broadcast summation occurring within the cell.
        """
        super(MyRnnCell, self).__init__()
        self._transition_params = tf.expand_dims(transition_params, 0)
        self._num_tags = transition_params.get_shape()[0].value

    @property
    def state_size(self):
        return self._num_tags

    @property
    def output_size(self):
        return self._num_tags

    def __call__(self, inputs, state, scope=None):
        """Build the CrfForwardRnnCell.

        Args:
          inputs: A [batch_size, num_tags] matrix of unary potentials.
          state: A [batch_size, num_tags] matrix containing the previous alpha
              values.
          scope: Unused variable scope of this cell.

        Returns:
          new_alphas, new_alphas: A pair of [batch_size, num_tags] matrices
              values containing the new alpha values.
        """
        state = tf.expand_dims(state, 2)

        # This addition op broadcasts self._transition_params along the zeroth
        # dimension and state along the second dimension. This performs the
        # multiplication of previous alpha values and the current binary potentials
        # in log space.
        transition_scores = state + self._transition_params
        new_alphas = inputs + tf.reduce_logsumexp(transition_scores, [1])
        output = new_alphas

        # Both the state and the output of this RNN cell contain the alpha values.
        # The output value is currently unused and simply satisfies the RNN API.
        # This could be useful in the future if we need to compute marginal
        # probabilities, which would require the accumulated alpha values at every
        # time step.
        return output, new_alphas


def rnn_test(inputs, transition_params, sequence_lengths):
    # Split up the first and rest of the inputs in preparation for the forward
    # algorithm.
    first_input = tf.slice(inputs, [0, 0, 0], [-1, 1, -1])
    first_input = tf.squeeze(first_input, [1])

    """Forward computation of alpha values."""
    rest_of_input = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])

    # Compute the alpha values in the forward algorithm in order to get the
    # partition function. Note that the first time step has already been
    # consumed as the initial state, so rest_of_input has max_seq_len - 1
    # steps (the original tf.contrib.crf code accordingly passes
    # sequence_length=sequence_lengths - 1 here).
    forward_cell = MyRnnCell(transition_params)

    output, state = tf.nn.dynamic_rnn(
        cell=forward_cell,
        inputs=rest_of_input,
        sequence_length=sequence_lengths,
        initial_state=first_input,
        dtype=tf.float32)

    return output, state


if __name__ == '__main__':
    inputs_arr = np.random.random([20, 10, 5])
    transition_params_arr = np.random.random([5, 5])
    sequence_lengths_arr = np.random.randint(0, 10, [20])

    inputs = tf.placeholder(tf.float32, [None, 10, 5])
    transition_params = tf.placeholder(tf.float32, [5, 5])
    sequence_lengths = tf.placeholder(tf.int64, [None])

    feed_dict = {inputs: inputs_arr, sequence_lengths: sequence_lengths_arr,
                 transition_params: transition_params_arr}

    sess = tf.Session()
    output, state = rnn_test(inputs, transition_params, sequence_lengths)
    print(sess.run([output, state], feed_dict=feed_dict))
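As a sanity check on what a single step of MyRnnCell computes, the same log-space alpha update can be written directly in NumPy (a hand-rolled sketch, not part of the original code; logsumexp over axis 1 mirrors tf.reduce_logsumexp(transition_scores, [1]) in __call__):

import numpy as np
from scipy.special import logsumexp

def one_step(inputs_t, prev_alphas, transition_params):
    # inputs_t, prev_alphas: [batch_size, num_tags];
    # transition_params: [num_tags, num_tags].
    # Broadcast to [batch_size, num_tags, num_tags], exactly as in MyRnnCell:
    # the previous alphas are expanded on axis 2, the transitions on axis 0.
    transition_scores = prev_alphas[:, :, None] + transition_params[None, :, :]
    # Log-space sum over the previous tag (axis 1), plus the unary potentials.
    return inputs_t + logsumexp(transition_scores, axis=1)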
