tensorflow中RNNcell源码分析以及自定义RNNCell的方法

本文深入探讨TensorFlow中RNNCell的使用,包括BasicRNNCell、GRUCell和BasicLSTMCell的实现原理,并介绍了如何自定义RNNCell,以适应Recurrent Entity Networks和Neural Turing Machines等复杂模型的需求。
摘要由CSDN通过智能技术生成

我们在仿真一些论文的时候经常会遇到一些模型,对RNN或者LSTM进行了少许的修改,或者自己定义了一种RNN的结构等情况,比如前面介绍的几篇memory networks的论文,往往都需要按照自己定义的方法来构造RNN网络。所以本篇博客就主要总结一下RNNcell的用法以及如何按照自己的需求自定义RNNCell。

tf中RNNCell的用法介绍

我们直接从源码的层面来看一看tf是如何实现RNNCell定义的。代码入下:

    class RNNCell(base_layer.Layer):

      def __call__(self, inputs, state, scope=None):

        if scope is not None:
          with vs.variable_scope(scope,
                                 custom_getter=self._rnn_get_variable) as scope:
            return super(RNNCell, self).__call__(inputs, state, scope=scope)
        else:
          with vs.variable_scope(vs.get_variable_scope(),
                                 custom_getter=self._rnn_get_variable):
            return super(RNNCell, self).__call__(inputs, state)

      def _rnn_get_variable(self, getter, *args, **kwargs):
        variable = getter(*args, **kwargs)
        trainable = (variable in tf_variables.trainable_variables() or
                     (isinstance(variable, tf_variables.PartitionedVariable) and
                      list(variable)[0] in tf_variables.trainable_variables()))
        if trainable and variable not in self._trainable_weights:
          self._trainable_weights.append(variable)
        elif not trainable and variable not in self._non_trainable_weights:
          self._non_trainable_weights.append(variable)
        return variable

      @property
      def state_size(self):
        raise NotImplementedError("Abstract method")

      @property
      def output_size(self):
        raise NotImplementedError("Abstract method")

      def build(self, _):
        pass

      def zero_state(self, batch_size, dtype):
        with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
          state_size = self.state_size
          return _zero_state_tensors(state_size, batch_size, dtype)

RNNCell是一个抽象的父类,其他的RNNcell都会继承该方法,然后具体实现其中的call()函数。从上面的定义中我们发现其主要有state_size和output_size两个属性,分别代表了隐藏层和输出层的维度。然后就是zero_state()和call()两个函数,分别用于初始化初始状态h0为全零向量和定义实际的RNNCell的操作(比如RNN就是一个激活,GRU的两个门,LSTM的三个门控等,不同的RNN的区别主要体现在这个函数)。有了这个抽象类,我们接下来看看tf中BasicRNNCell、GRUCell、BasicLSTMCell三个cell的定义方法,了解不同变种RNN模型的定义方式的区别和实现方法。

    class BasicRNNCell(RNNCell):

      def __init__(self, num_units, activation=None, reuse=None):
        super(BasicRNNCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation or math_ops.tanh

      @property
      def state_size(self):
        return self._num_units

      @property
      def output_size(self):
        return self._num_units

      def call(self, inputs, state):
        output = self._activation(_linear([inputs, state], self._num_units, True))
        return output, output

最简单的RNN结构如上图所示,可以看出BasicRNNCell中把state_size和output_size定义成相同,而且ht和output也是相同的(看call函数的输出是两个output,也就是其并未定义输出部分)。再看一下其主要功能实现就是call函数的第一行,就是input和前一时刻状态state经过一个线性函数在经过一个激活函数即可,也就是最普通的RNN定义方式。也就是说output = new_state = f(W * input + U * state + B)。接下来我们看一下GRU的定义:

    class GRUCell(RNNCell):

      def __init__(self,
                   num_units,
                   activation=None,
                   reuse=None,
                   kernel_initializer=None,
                   bias_initializer=None):
        super(GRUCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation or math_ops.tanh
        self._kernel_initializer = kernel_initializer
        self._bias_initializer = bias_initializer

      @property
      def state_size(self):
        return self._num_units

      @property
      def output_size(self):
        return self._num_units

      def call(self, inputs, state):
        with vs.variable_scope("gates"):  # Reset gate and update gate.
          # We start with bias of 1.0 to not reset and not update.
          bias_ones = self._bias_initializer
          if self._bias_initializer is None:
            dtype = [a.dtype for a in [inputs, state]][0]
            bi
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值