Tensorflow中RNNCell源码解析

RNNCell本地文件的路径:~/anaconda3/envs/tensorflow/lib/python3.6/site-packages/tensorflow/python/ops/rnn_cell_impl.py

Github地址:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py
有兴趣的可以直接去查看实现源码。

RNNCell是所有cell的父类,LayerRNNCell的设计是为了方便参数的管理,其他的cell和wrapper等结构均是根据具体论文和相关应用场景设计实现的。
rnn_cell的继承关系:

LayerRNNCell(RNNCell)

BasicRNNCell(LayerRNNCell)

BasicLSTMCell(LayerRNNCell)
LSTMCell(LayerRNNCell)
GRUCell(LayerRNNCell)

rnn_cell wrapper 用于装载rnncell,warpper均继承自RNNCell,因为它们的实现是在rnncell的信息的输入、输出和信息的自环的过程中进行操作,该过程和rnn_cell有相同的自环操作过程。主要包含:

  • MultiRNNCell
  • DropoutWrapper
  • DeviceWrapper
  • ResidualWrapper

RNNCell

    class RNNCell(base_layer.Layer):

      def __call__(self, inputs, state, scope=None):

        if scope is not None:
          with vs.variable_scope(scope,
                                 custom_getter=self._rnn_get_variable) as scope:
            return super(RNNCell, self).__call__(inputs, state, scope=scope)
        else:
          with vs.variable_scope(vs.get_variable_scope(),
                                 custom_getter=self._rnn_get_variable):
            return super(RNNCell, self).__call__(inputs, state)

      def _rnn_get_variable(self, getter, *args, **kwargs):
        variable = getter(*args, **kwargs)
        trainable = (variable in tf_variables.trainable_variables() or
                     (isinstance(variable, tf_variables.PartitionedVariable) and
                      list(variable)[0] in tf_variables.trainable_variables()))
        if trainable and variable not in self._trainable_weights:
          self._trainable_weights.append(variable)
        elif not trainable and variable not in self._non_trainable_weights:
          self._non_trainable_weights.append(variable)
        return variable

      @property
      def state_size(self):
        raise NotImplementedError("Abstract method")

      @property
      def output_size(self):
        raise NotImplementedError("Abstract method")

      def build(self, _):
        pass

      def zero_state(self, batch_size, dtype):
        with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
          state_size = self.state_size
          return _zero_state_tensors(state_size, batch_size, dtype)

RNNCell是一个抽象的父类,其他的RNNcell都会继承该方法,然后具体实现其中的call()函数。从上面的定义中我们发现其主要有state_size和output_size两个属性,分别代表了隐藏层和输出层的维度。然后就是zero_state()和call()两个函数,分别用于初始化初始状态h0为全零向量和定义实际的RNNCell的操作(比如RNN就是一个激活,GRU的两个门,LSTM的三个门控等,不同的RNN的区别主要体现在这个函数)。有了这个抽象类,我们接下来看看tf中BasicRNNCell、GRUCell、BasicLSTMCell三个cell的定义方法,了解不同变种RNN模型的定义方式的区别和实现方法。

BasicRNNCell

首先是BasicRNNCell,就是我们常说的RNN
在这里插入图片描述
最简单的RNN结构如上图所示,可以看出BasicRNNCell中把state_size和output_size定义成相同,而且ht和output也是相同的(看call函数的输出是两个output,也就是其并未定义输出部分)。再看一下其主要功能实现就是call函数的第一行,就是input和前一时刻状态state经过一个线性函数在经过一个激活函数即可,也就是最普通的RNN定义方式。也就是说output = new_state = f(W * input + U * state + B)。

    class BasicRNNCell(RNNCell):

      def __init__(self, num_units, activation=None, reuse=None):
        super(BasicRNNCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation or math_ops.tanh

      @property
      def state_size(self):
        return self._num_units

      @property
      def output_size(self):
        return self._num_units

      def call(self, inputs, state):
        output = self._activation(_linear([inputs, state], self._num_units, True))
        return output, output

GRUCell

接下来我们看一下GRU的定义,相比BasicRNNCell只改变了call函数部分,增加了重置门和更新门两部分,分别由r和u表示。然后c表示要更新的状态值。其对应的图及公式如下所示:
在这里插入图片描述

    r = f(W1 * input + U1 * state + B1)
    u = f(W2 * input + U2 * state + B2)
    c = f(W3 * input + U3 * r * state + B3)
    h_new = u * h + (1 - u) * c

GRUCell的实现代码如下:

    class GRUCell(RNNCell):

      def __init__(self,
                   num_units,
                   activation=None,
                   reuse=None,
                   kernel_initializer=None,
                   bias_initializer=None):
        super(GRUCell, self).__init__(_reuse=reuse)
        self._num_units = num_units
        self._activation = activation or math_ops.tanh
        self._kernel_initializer = kernel_initializer
        self._bias_initializer = bias_initializer

      @property
      def state_size(self):
        return self._num_units

      @property
      def output_size(self):
        return self._num_units

      def call(self, inputs, state):
        with vs.variable_scope("gates"):  # Reset gate and update gate.
          # We start with bias of 1.0 to not reset and not update.
          bias_ones = self._bias_initializer
          if self._bias_initializer is None:
            dtype = [a.dtype for a in [inputs, state]][0]
            bias_ones = init_ops.constant_initializer(1.0, dtype=dtype)
          value = math_ops.sigmoid(
              _linear([inputs, state], 2 * self._num_units, True, bias_ones,
                      self._kernel_initializer))
          r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
        with vs.variable_scope("candidate"):
          c = self._activation(
              _linear([inputs, r * state], self._num_units, True,
                      self._bias_initializer, self._kernel_initializer))
        new_h = u * state + (1 - u) * c
        return new_h, new_h

BasicLSTMCell

接下来再看一下BasicLSTMCell的实现方法,相比GRU,LSTM又多了一个输出门,而且又新增添了一个C表示其内部状态,然后将h和c以tuple的形式返回作为LSTM内部的状态变量。其内部结构和公式表示如下图所示:
在这里插入图片描述具体实现公式如下:
在这里插入图片描述
BasicLSTMCell的实现代码如下:

    class BasicLSTMCell(RNNCell):

      def __init__(self, num_units, forget_bias=1.0,
                   state_is_tuple=True, activation=None, reuse=None):

        super(BasicLSTMCell, self).__init__(_reuse=reuse)
        if not state_is_tuple:
          logging.warn("%s: Using a concatenated state is slower and will soon be "
                       "deprecated.  Use state_is_tuple=True.", self)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation or math_ops.tanh

      @property
      def state_size(self):
        return (LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)

      @property
      def output_size(self):
        return self._num_units

      def call(self, inputs, state):
        sigmoid = math_ops.sigmoid
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
          c, h = state
        else:
          c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

        concat = _linear([inputs, h], 4 * self._num_units, True)

        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

        new_c = (
            c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        if self._state_is_tuple:
          new_state = LSTMStateTuple(new_c, new_h)
        else:
          new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state

从上面的代码可以发现,其与BasicRNNCell和GRUCell的区别也主要在call()函数上,不同的功能实现也都在call里面进行。不难发现,还有一个在三个累里面都频繁使用到的函数_linear(),这个函数的作用是什么呢,我们先来看一下其源码:

_linear函数

    def _linear(args,
                output_size,
                bias,
                bias_initializer=None,
                kernel_initializer=None):

      if args is None or (nest.is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
      if not nest.is_sequence(args):
        args = [args]

      # Calculate the total size of arguments on dimension 1.
      total_arg_size = 0
      shapes = [a.get_shape() for a in args]
      for shape in shapes:
        if shape.ndims != 2:
          raise ValueError("linear is expecting 2D arguments: %s" % shapes)
        if shape[1].value is None:
          raise ValueError("linear expects shape[1] to be provided for shape %s, "
                           "but saw %s" % (shape, shape[1]))
        else:
          total_arg_size += shape[1].value

      dtype = [a.dtype for a in args][0]

      # Now the computation.
      scope = vs.get_variable_scope()
      with vs.variable_scope(scope) as outer_scope:
        weights = vs.get_variable(
            _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size],
            dtype=dtype,
            initializer=kernel_initializer)
        if len(args) == 1:
          res = math_ops.matmul(args[0], weights)
        else:
          res = math_ops.matmul(array_ops.concat(args, 1), weights)
        if not bias:
          return res
        with vs.variable_scope(outer_scope) as inner_scope:
          inner_scope.set_partitioner(None)
          if bias_initializer is None:
            bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
          biases = vs.get_variable(
              _BIAS_VARIABLE_NAME, [output_size],
              dtype=dtype,
              initializer=bias_initializer)
        return nn_ops.bias_add(res, biases)

这个函数的输入args就是[input, state],而output_size是输出层的大小,我们可以看到BasicRNNCell中,output_size就是_num_units,GRUCell中是2*_num_units,BasicLSTMCell中是4*_num_units,这是因为_linear中执行的是RNN中的几个等式的Wx+Uh+B的功能,但是不同的RNN中数量不同,比如LSTM中需要计算四次,然后直接把output_size定义为4*_num_units,再把输出进行拆分成四个变量即可~~

自定义RNNCell

看完BasicRNNCell 、GRUCell和BasicLSTMCell的实现方案,应该不难想象出自定义RNNCell的方法,那就是继承_LayerRNNCell这个抽象类,然后一定要实现__init__、build、__call__这三个函数就行了,其中在call函数中实现自己需要的功能即可。(注意:build只调用一次,在build中进行变量实例化,在call中实现具体的rnncell操作)。

参考博客:
tensorflow中RNNcell源码分析以及自定义RNNCell的方法
Tensorflow rnn_cell api 阅读笔记

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值