神经翻译笔记4扩展a第二部分. RNN在TF2.0中的实现方法略览
文章目录
与TF1.x的实现思路不同,在TF2.0中,RNN已经不再是个函数,而是一个封装好的类,各种RNN和RNNCell与顶层抽象类
Layer
的关系也更加紧凑(需要说明的是说
Layer
顶层并非说它直接继承自
object
,而是从……功能的角度,我觉得可以这么说。真实实现里的继承关系是
Layer --> Module --> AutoTrackable --> Trackable --> object
)。但是另一方面,感觉新的版本里各个类的关系稍微有些杂乱,不知道后面会不会进一步重构。TF2.0的RNN相关各类关系大致如下图所示
相关基类
tf.keras.layers.Layer
与TF1.14的实现基本相同,不再赘述
recurrent.DropoutRNNCellMixin
与之类似的类在TF1.x中以tf.nn.rnn_cell.DropoutWrapper
形式出现,但当时考虑到还没涉及到RNN的dropout就没有引入,没想到在这里还是要说一说。TF2的实现比TF1的实现要简单一些,这个类只是维护两个dropout mask,一个是用于对输入的mask,一个用于对传递状态的mask(严格说是四个,在另一个维度上还考虑是对静态图的mask还是对eager模式的mask)。实现保证mask只被创建一次,因此每个batch使用的mask都相同
RNNCell相关
无论是官方给出的文本分类教程,还是我自己从TF1.x改的用更底层API实现的代码,实际上都没有用到Cell相关的对象。但是为了完整起见(毕竟暴露的LSTM
类背后还需要LSTMCell
类对象作为自己的成员变量),这里还是稍作介绍
LSTMCell
本文以LSTM为主,因此先从LSTMCell
说起。与TF1.x不同,在2.x版本里,LSTMCell
允许传入一个implement
参数,默认为1,标记LSTM各门和输出、状态的计算方式。当取默认的1时,计算方式更像是论文中的方式,逐个计算各个门的结果;而如果设为2,则使用TF1.x中组合成矩阵一并计算的方式。此外,由于LSTMCell
还继承了前述DropoutRNNCellMixin
接口,因此可以在call
里对输入和上一时间步传来的状态做dropout。注意由于LSTM有四个内部变量 i \boldsymbol{i} i、 f \boldsymbol{f} f、 o \boldsymbol{o} o和 c ~ \tilde{\boldsymbol{c}} c~