神经翻译笔记4扩展a第二部分. RNN在TF2.0中的实现方法略览

神经翻译笔记4扩展a第二部分. RNN在TF2.0中的实现方法略览


与TF1.x的实现思路不同,在TF2.0中,RNN已经不再是个函数,而是一个封装好的类,各种RNN和RNNCell与顶层抽象类 Layer的关系也更加紧凑(需要说明的是说 Layer顶层并非说它直接继承自 object,而是从……功能的角度,我觉得可以这么说。真实实现里的继承关系是 Layer --> Module --> AutoTrackable --> Trackable --> object)。但是另一方面,感觉新的版本里各个类的关系稍微有些杂乱,不知道后面会不会进一步重构。TF2.0的RNN相关各类关系大致如下图所示

在这里插入图片描述

相关基类

tf.keras.layers.Layer

与TF1.14的实现基本相同,不再赘述

recurrent.DropoutRNNCellMixin

与之类似的类在TF1.x中以tf.nn.rnn_cell.DropoutWrapper形式出现,但当时考虑到还没涉及到RNN的dropout就没有引入,没想到在这里还是要说一说。TF2的实现比TF1的实现要简单一些,这个类只是维护两个dropout mask,一个是用于对输入的mask,一个用于对传递状态的mask(严格说是四个,在另一个维度上还考虑是对静态图的mask还是对eager模式的mask)。实现保证mask只被创建一次,因此每个batch使用的mask都相同

RNNCell相关

无论是官方给出的文本分类教程,还是我自己从TF1.x改的用更底层API实现的代码,实际上都没有用到Cell相关的对象。但是为了完整起见(毕竟暴露的LSTM类背后还需要LSTMCell类对象作为自己的成员变量),这里还是稍作介绍

LSTMCell

本文以LSTM为主,因此先从LSTMCell说起。与TF1.x不同,在2.x版本里,LSTMCell允许传入一个implement参数,默认为1,标记LSTM各门和输出、状态的计算方式。当取默认的1时,计算方式更像是论文中的方式,逐个计算各个门的结果;而如果设为2,则使用TF1.x中组合成矩阵一并计算的方式。此外,由于LSTMCell还继承了前述DropoutRNNCellMixin接口,因此可以在call里对输入和上一时间步传来的状态做dropout。注意由于LSTM有四个内部变量 i \boldsymbol{i} i f \boldsymbol{f} f o \boldsymbol{o} o c ~ \tilde{\boldsymbol{c}} c~

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值