神经翻译笔记4扩展a第二部分. RNN在TF2.0中的实现方法略览
文章目录
与TF1.x的实现思路不同,在TF2.0中,RNN已经不再是个函数,而是一个封装好的类,各种RNN和RNNCell与顶层抽象类
Layer
的关系也更加紧凑(需要说明的是说
Layer
顶层并非说它直接继承自
object
,而是从……功能的角度,我觉得可以这么说。真实实现里的继承关系是
Layer --> Module --> AutoTrackable --> Trackable --> object
)。但是另一方面,感觉新的版本里各个类的关系稍微有些杂乱,不知道后面会不会进一步重构。TF2.0的RNN相关各类关系大致如下图所示
相关基类
tf.keras.layers.Layer
与TF1.14的实现基本相同,不再赘述
recurrent.DropoutRNNCellMixin
与之类似的类在TF1.x中以tf.nn.rnn_cell.DropoutWrapper
形式出现,但当时考虑到还没涉及到RNN的dropout就没有引入,没想到在这里还是要说一说。TF2的实现比TF1的实现要简单一些,这个类只是维护两个dropout mask,一个是用于对输入的mask,一个用于对传递状态的mask(严格说是四个,在另一个维度上还考虑是对静态图的mask还是对eager模式的mask)。实现保证mask只被创建一次,因此每个batch使用的mask都相同
RNNCell相关
无论是官方给出的文本分类教程,还是我自己从TF1.x改的用更底层API实现的代码,实际上都没有用到Cell相关的对象。但是为了完整起见(毕竟暴露的LSTM
类背后还需要LSTMCell
类对象作为自己的成员变量),这里还是稍作介绍
LSTMCell
本文以LSTM为主,因此先从LSTMCell
说起。与TF1.x不同,在2.x版本里,LSTMCell
允许传入一个implement
参数,默认为1,标记LSTM各门和输出、状态的计算方式。当取默认的1时,计算方式更像是论文中的方式,逐个计算各个门的结果;而如果设为2,则使用TF1.x中组合成矩阵一并计算的方式。此外,由于LSTMCell
还继承了前述DropoutRNNCellMixin
接口,因此可以在call
里对输入和上一时间步传来的状态做dropout。注意由于LSTM有四个内部变量 i \boldsymbol{i} i、 f \boldsymbol{f} f、 o \boldsymbol{o} o和 c ~ \tilde{\boldsymbol{c}} c~,因此需要各自生成四个不同的dropout mask
PeepholeLSTMCell
只是改写了LSTMCell
内部变量的计算逻辑,参见在TF1.x部分的介绍
StackedRNNCells
与TF1.x中的MultiRNNCell
类似
AbstractRNNCell
纯抽象类,类似TF1的RNNCell
,如果用户自己实现一个RNNCell
,需要 可以继承于它。不过有趣的是内置的三种RNN
实现所使用的Cell:SimpleRNNCell
、GRUCell
、LSTMCell
均直接继承自Layer
RNN相关
tf.keras.layers.RNN
所有后续RNN相关类的基类,承担TF1.x中static_rnn
和dynamic_rnn
的双重功能,主要逻辑分别集中在初始化函数__init__
、build
和call
中(__call__
也有一些逻辑,但是只针对某些特殊情况)
RNN
在初始化时传入的参数个人感觉相对来讲不如1.x直观。其允许传入的参数包括
cell
:一种RNNCell的对象,也可以是列表或元组。当传入的参数为列表或元组时,会打包组合为StackedRNNCells
类对象return_sequences
:默认RNN只返回最后一个时间步的输出。当此参数设为True
时,返回每个时间步的输出return_state
:当此参数设为