神经翻译笔记4扩展a第一部分. RNN在TF1.x中的实现方法略览

本文深入探讨TensorFlow 1.14中RNN的实现,包括RNN cell的基础结构、静态与动态机制,并通过代码示例解析Vanilla RNN、GRU和LSTM的实现细节。重点介绍了RNN cell的实现逻辑和RNN信息的时序流动方式。
摘要由CSDN通过智能技术生成

本文主要讨论TF1.14对RNN的实现。尽管更老的版本实现可能有差别,但考虑其有些过时,因此这里略过

RNN cell的实现

TF1.x对RNN cell的实现总体遵循如下所示之类图
TF1.x中各RNNCell实现的关系图

keras.layers.Layer

代码位置:tensorflow/python/keras/engine/base_layer.py

该类是所有层(“layers”)的基类。按照文档,所谓“层”实现了神经网络的一些常用操作。既然如此,那么一个“层”应该有如下特点:

  • 首先,既然定义了“操作”,那么肯定会有一个“宾语”。这意味着每个层都会接收输入,作为操作对象
  • 其次,“操作”本身一定有逻辑,因此每个层内部会定义该层对输入应如何处理,执行怎样的计算过程。通常情况下,每个层都需要设置一些可训练的参数
  • 最后,该操作通常会产生输出,将输出传递给下一个层,作为下一个层的输入

由于层的核心是计算过程,因此在这样的基类中,最核心的方法是覆写的__call__方法,使得所有继承该类的子类对象是可调用的。此类已经对__call__方法做了比较好的包装,因此所有具体的层只需要实现该类提供的接口buildcall即可。这两个接口会在__call__方法中调用。基类实现的__call__方法逻辑大致为

def __call__(inputs):
    将inputs中所有numpy类型数据转换为tensor
    build_graph = 判断是否建图
    previous_mask = 从前层获取mask值
    with base_layer_utils.call_context(build_graph):
        if build_graph:
            判断输入是否满足input_spec
        # 下述核心逻辑在静态图(build_graph is True)和eager模式下是相同的
        # 但因为两者背后实现逻辑有区别,所以原始代码分开做了实现。这里做了合并
        self._maybe_build(inputs)  # 实际就是调用build。该方法会在开始检查self._built是否为True
                                   # 如果为True直接返回。在最后会显式将self._built设置为True
        outputs = self.call(inputs)
        # 对输出做正则,并将正则损失加到loss中
        self._handle_activity_regularization(inputs, outputs)
        # 计算并设置mask
        self._set_mask_metadata(inputs, outputs, previous_mask)
        return outputs

此外,子类一般还需要实现__init__方法。子类所需要实现的三个方法大致分工如下

  • __init__来创建并初始化(一部分)成员变量,但是不指定训练参数
  • build内通常调用add_weight方法,根据输入、类型、形状(shape)、使用的初始化方法等信息创建要训练的参数
  • call内部实现具体的计算逻辑,返回outputs

如果损失计算或权重更新与输入有关,可通过让子类实现add_lossadd_update以达到此目的

layers.Layer

代码位置:tensorflow/python/layers/base.py

该类存在的主要目的是向下兼容静态图模式的代码。在TF1.x的早期版本(至少到1.5)该类事实上是所有“层”类的基类(直接继承自object),不过在1.x的后期版本该类的核心逻辑已被移动到前述tf.keras.layers.Layer中,官方也不再推荐开发者继承该类开发

nn.rnn_cell.RNNCell

代码位置:tensorflow/python/ops/rnn_cell_impl.py。至另有标注为止,后续各类的实现均在该文件中

该类为所有具体的RNN实现提供了一个共同的抽象表示。通过重写__call__方法,将其签名改为__call__(self, inputs, state, scope=None),使得该类的所有子类对象以函数的方式被“调用”时,参数除了inputs以外还需要再带一个state参数(通常是上一个时间步传递来的状态)。对应地,具体实现call时也需要传入这两个参数(当然,实际上仍然可以只写明传入inputs,而state通过args, **kwargs传入。但是这样看上去像是在杠)。此外,该类还提供了get_initial_statezero_state方法,后者常用来初始化RNN的初始状态

LayerRNNCell

为了向layers.Layer靠拢插进来的新类。其作用是将变量创建(build该做的事情)从call中剥离出来。具体的解(tu)释(cao)可以参看为什么感觉tensorflow的源码写的很多余? - Towser的回答 - 知乎

BasicRNNCell

BasicRNNCell实现的是没有任何门控的Vanilla RNN。我一直不很倾向在文字里放大段原始代码,但是由于该类背后逻辑比较简单,实现相对简短,因此这里将会给出原始实现的具体细节

__init__函数设置自身的input_spec、单元数num_units和激活函数种类,默认激活函数为tanh。如代码文档所说,这里的“cell”和文献里的“cell”不同。文献里的cell输出一个标量,但是这里的一个cell相当于一组文献中的cell,共num_units个。考虑之前文章贴的示意图
RNN示意图
其展开图如下
RNN详细展开图这里RNN cell的num_units即为3(每个时间步黄色圆圈的数量)
__init__的具体实现如下:

def __init__(self, num_units, activation=None, reuse=None, 
             name=None, dtype=None, **kwargs):
    super(BasicRNNCell, self).__init__(_reuse=reuse, name=name, 
                                       dtype=dtype, **kwargs)
    # 只接受浮点型或复数浮点型输入
    _check_supported_dtypes(self.dtype)
    if context.executing_eagerly() and context.num_gpus() > 0:
        logging.warn("建议使用tf.contrib.cudnn_rnn.CudnnRNNTanh以达到更好性能")

    # 要求输入是2维。由本文前面介绍,会在调用__call__时检查input spec
    self.input_spec = input_spec.InputSpec(ndim=2)

    self._num_units = num_units
    if activation:
        self._activation = activations.get(activation)
    else:
        self._activation = math_ops.tanh

build只是通过最基类tf.Layers提供的add_weight方法来注册变量。add_weight在调用时会访问初始化时提供的信息,例如参数使用何种初始化方法、是否可训练等,并将参数加入图中(静态图模式下)。build的具体实现为

def build(self, inputs_shape):
    if inputs_shape[-1] is None:
        raise ValueError
    _check_supported_dtypes(self.dtype)

    # 与词向量连接的那一层,input_depth就是词向量维度
    input_depth = input_shape[-1]
    # add_variable就是add_weight的别名
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + self._num_units, self._num_units])
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))
    self.built = True

call则是实现计算逻辑
o u t p u t = s t a t e = a c t i v a t e ( x ( t ) U + h ( t − 1 ) W + b ) {output} = {state} = {\rm activate}\left(\boldsymbol{x}^{(t)}\boldsymbol{U} + \boldsymbol{h}^{(t-1)}\boldsymbol{W}+\boldsymbol{b}\right) output=state=activate(x(t)U+h(t1)W+b)
其中

  • x t \boldsymbol{x}_t xt形状为 b a t c h _ s i z e × i n p u t _ d e p t h batch\_size \times input\_depth batch_size×input_depth,对应代码中的inputs(这里使用小写,是取batch size为1时的情况。此时输入是一个行向量)
  • h t − 1 \boldsymbol{h}_{t-1} ht1形状为 b a t c h _ s i z e × n u m _ u n i t s batch\_size \times num\_units batch_size×num_units,对应代码中的state
  • U \boldsymbol{U} U形状为 i n p u t _ d e p t h × n u m _ u n i t s input\_depth \times num\_units input_depth×num_units
  • W \boldsymbol{W} W形状为 n u m _ u n i t s × n u m _ u n i t s num\_units \times num\_units num_units×num_units

(这里矩阵乘法顺序与理论部分有区别,原因是理论部分将输入向量看做列向量,而这里使用行向量实现。下同)

真正实现时,将 U \boldsymbol{U} U W \boldsymbol{W} W“纵向”连接在一起(这个矩阵的整体为self._kernel), x ( t ) \boldsymbol{x}^{(t)} x(t) h ( t − 1 ) \boldsymbol{h}^{(t-1)} h(t1)按“横向”连接在一起。这样仅通过一次矩阵乘法运算,就能算出结果。即
[ x ( t ) h ( t − 1 ) ] ⋅ [ U W ] = x ( t ) U + h ( t − 1 ) W \left[\begin{matrix}\boldsymbol{x}^{(t)} & \boldsymbol{h}^{(t-1)}\end{matrix}\right] \cdot \left[\begin{matrix}\boldsymbol{U} \\ \boldsymbol{W} \end{matrix}\right] = \boldsymbol{x}^{(t)}\boldsymbol{U} + \boldsymbol{h}^{(t-1)}\boldsymbol{W} [x(t)h(t1)][UW]=x(t)U+h(t1)W
这也是为什么在buildself._kernel的形状为[input_depth + self._num_units, self._num_units]的原因

具体代码如下

def call(self, inputs, state):
    _check_rnn_cell_input_dtypes([inputs, state])
    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, state], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
    output = self._activation(gate_inputs)
    # 实际为output, state。但是vanilla RNN不区分
    return output, output

GRUCell

GRUCell和后面要介绍的LSTMCell实现思路与前述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值