神经翻译笔记4扩展a第一部分. RNN在TF1.x中的实现方法略览
本文主要讨论TF1.14对RNN的实现。尽管更老的版本实现可能有差别,但考虑其有些过时,因此这里略过
RNN cell的实现
TF1.x对RNN cell的实现总体遵循如下所示之类图
keras.layers.Layer
代码位置:tensorflow/python/keras/engine/base_layer.py
该类是所有层(“layers”)的基类。按照文档,所谓“层”实现了神经网络的一些常用操作。既然如此,那么一个“层”应该有如下特点:
- 首先,既然定义了“操作”,那么肯定会有一个“宾语”。这意味着每个层都会接收输入,作为操作对象
- 其次,“操作”本身一定有逻辑,因此每个层内部会定义该层对输入应如何处理,执行怎样的计算过程。通常情况下,每个层都需要设置一些可训练的参数
- 最后,该操作通常会产生输出,将输出传递给下一个层,作为下一个层的输入
由于层的核心是计算过程,因此在这样的基类中,最核心的方法是覆写的__call__
方法,使得所有继承该类的子类对象是可调用的。此类已经对__call__
方法做了比较好的包装,因此所有具体的层只需要实现该类提供的接口build
和call
即可。这两个接口会在__call__
方法中调用。基类实现的__call__
方法逻辑大致为
def __call__(inputs):
将inputs中所有numpy类型数据转换为tensor
build_graph = 判断是否建图
previous_mask = 从前层获取mask值
with base_layer_utils.call_context(build_graph):
if build_graph:
判断输入是否满足input_spec
# 下述核心逻辑在静态图(build_graph is True)和eager模式下是相同的
# 但因为两者背后实现逻辑有区别,所以原始代码分开做了实现。这里做了合并
self._maybe_build(inputs) # 实际就是调用build。该方法会在开始检查self._built是否为True
# 如果为True直接返回。在最后会显式将self._built设置为True
outputs = self.call(inputs)
# 对输出做正则,并将正则损失加到loss中
self._handle_activity_regularization(inputs, outputs)
# 计算并设置mask
self._set_mask_metadata(inputs, outputs, previous_mask)
return outputs
此外,子类一般还需要实现__init__
方法。子类所需要实现的三个方法大致分工如下
__init__
来创建并初始化(一部分)成员变量,但是不指定训练参数build
内通常调用add_weight
方法,根据输入、类型、形状(shape)、使用的初始化方法等信息创建要训练的参数call
内部实现具体的计算逻辑,返回outputs
如果损失计算或权重更新与输入有关,可通过让子类实现add_loss
或add_update
以达到此目的
layers.Layer
代码位置:tensorflow/python/layers/base.py
该类存在的主要目的是向下兼容静态图模式的代码。在TF1.x的早期版本(至少到1.5)该类事实上是所有“层”类的基类(直接继承自object
),不过在1.x的后期版本该类的核心逻辑已被移动到前述tf.keras.layers.Layer
中,官方也不再推荐开发者继承该类开发
nn.rnn_cell.RNNCell
代码位置:tensorflow/python/ops/rnn_cell_impl.py
。至另有标注为止,后续各类的实现均在该文件中
该类为所有具体的RNN实现提供了一个共同的抽象表示。通过重写__call__
方法,将其签名改为__call__(self, inputs, state, scope=None)
,使得该类的所有子类对象以函数的方式被“调用”时,参数除了inputs
以外还需要再带一个state
参数(通常是上一个时间步传递来的状态)。对应地,具体实现call
时也需要传入这两个参数(当然,实际上仍然可以只写明传入inputs
,而state
通过args, **kwargs
传入。但是这样看上去像是在杠)。此外,该类还提供了get_initial_state
和zero_state
方法,后者常用来初始化RNN的初始状态
LayerRNNCell
为了向layers.Layer
靠拢插进来的新类。其作用是将变量创建(build
该做的事情)从call
中剥离出来。具体的解(tu)释(cao)可以参看为什么感觉tensorflow的源码写的很多余? - Towser的回答 - 知乎
BasicRNNCell
BasicRNNCell
实现的是没有任何门控的Vanilla RNN。我一直不很倾向在文字里放大段原始代码,但是由于该类背后逻辑比较简单,实现相对简短,因此这里将会给出原始实现的具体细节
__init__
函数设置自身的input_spec
、单元数num_units
和激活函数种类,默认激活函数为tanh。如代码文档所说,这里的“cell”和文献里的“cell”不同。文献里的cell输出一个标量,但是这里的一个cell相当于一组文献中的cell,共num_units
个。考虑之前文章贴的示意图
其展开图如下
这里RNN cell的num_units
即为3(每个时间步黄色圆圈的数量)
__init__
的具体实现如下:
def __init__(self, num_units, activation=None, reuse=None,
name=None, dtype=None, **kwargs):
super(BasicRNNCell, self).__init__(_reuse=reuse, name=name,
dtype=dtype, **kwargs)
# 只接受浮点型或复数浮点型输入
_check_supported_dtypes(self.dtype)
if context.executing_eagerly() and context.num_gpus() > 0:
logging.warn("建议使用tf.contrib.cudnn_rnn.CudnnRNNTanh以达到更好性能")
# 要求输入是2维。由本文前面介绍,会在调用__call__时检查input spec
self.input_spec = input_spec.InputSpec(ndim=2)
self._num_units = num_units
if activation:
self._activation = activations.get(activation)
else:
self._activation = math_ops.tanh
build
只是通过最基类tf.Layers
提供的add_weight
方法来注册变量。add_weight
在调用时会访问初始化时提供的信息,例如参数使用何种初始化方法、是否可训练等,并将参数加入图中(静态图模式下)。build
的具体实现为
def build(self, inputs_shape):
if inputs_shape[-1] is None:
raise ValueError
_check_supported_dtypes(self.dtype)
# 与词向量连接的那一层,input_depth就是词向量维度
input_depth = input_shape[-1]
# add_variable就是add_weight的别名
self._kernel = self.add_variable(
_WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, self._num_units])
self._bias = self.add_variable(
_BIAS_VARIABLE_NAME,
shape=[self._num_units],
initializer=init_ops.zeros_initializer(dtype=self.dtype))
self.built = True
call
则是实现计算逻辑
o u t p u t = s t a t e = a c t i v a t e ( x ( t ) U + h ( t − 1 ) W + b ) {output} = {state} = {\rm activate}\left(\boldsymbol{x}^{(t)}\boldsymbol{U} + \boldsymbol{h}^{(t-1)}\boldsymbol{W}+\boldsymbol{b}\right) output=state=activate(x(t)U+h(t−1)W+b)
其中
- x t \boldsymbol{x}_t xt形状为 b a t c h _ s i z e × i n p u t _ d e p t h batch\_size \times input\_depth batch_size×input_depth,对应代码中的
inputs
(这里使用小写,是取batch size为1时的情况。此时输入是一个行向量) - h t − 1 \boldsymbol{h}_{t-1} ht−1形状为 b a t c h _ s i z e × n u m _ u n i t s batch\_size \times num\_units batch_size×num_units,对应代码中的
state
- U \boldsymbol{U} U形状为 i n p u t _ d e p t h × n u m _ u n i t s input\_depth \times num\_units input_depth×num_units
- W \boldsymbol{W} W形状为 n u m _ u n i t s × n u m _ u n i t s num\_units \times num\_units num_units×num_units
(这里矩阵乘法顺序与理论部分有区别,原因是理论部分将输入向量看做列向量,而这里使用行向量实现。下同)
真正实现时,将 U \boldsymbol{U} U与 W \boldsymbol{W} W“纵向”连接在一起(这个矩阵的整体为self._kernel
), x ( t ) \boldsymbol{x}^{(t)} x(t)与 h ( t − 1 ) \boldsymbol{h}^{(t-1)} h(t−1)按“横向”连接在一起。这样仅通过一次矩阵乘法运算,就能算出结果。即
[ x ( t ) h ( t − 1 ) ] ⋅ [ U W ] = x ( t ) U + h ( t − 1 ) W \left[\begin{matrix}\boldsymbol{x}^{(t)} & \boldsymbol{h}^{(t-1)}\end{matrix}\right] \cdot \left[\begin{matrix}\boldsymbol{U} \\ \boldsymbol{W} \end{matrix}\right] = \boldsymbol{x}^{(t)}\boldsymbol{U} + \boldsymbol{h}^{(t-1)}\boldsymbol{W} [x(t)h(t−1)]⋅[UW]=x(t)U+h(t−1)W
这也是为什么在build
里self._kernel
的形状为[input_depth + self._num_units, self._num_units]
的原因
具体代码如下
def call(self, inputs, state):
_check_rnn_cell_input_dtypes([inputs, state])
gate_inputs = math_ops.matmul(
array_ops.concat([inputs, state], 1), self._kernel)
gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)
output = self._activation(gate_inputs)
# 实际为output, state。但是vanilla RNN不区分
return output, output
GRUCell
GRUCell
和后面要介绍的LSTMCell
实现思路与前述