神经翻译笔记4扩展a第二部分. RNN在TF2.0中的实现方法略览

这篇博客详细介绍了在TensorFlow 2.0中RNN的相关实现,包括RNNCell的基础类和LSTM的使用。文章探讨了RNNCell的 dropout 实现、LSTM的门机制以及RNN类的参数和功能。还提到了双向RNN的实现,并指出TF2.0中速度优化的CudnnRNN。
摘要由CSDN通过智能技术生成

神经翻译笔记4扩展a第二部分. RNN在TF2.0中的实现方法略览


与TF1.x的实现思路不同,在TF2.0中,RNN已经不再是个函数,而是一个封装好的类,各种RNN和RNNCell与顶层抽象类 Layer的关系也更加紧凑(需要说明的是说 Layer顶层并非说它直接继承自 object,而是从……功能的角度,我觉得可以这么说。真实实现里的继承关系是 Layer --> Module --> AutoTrackable --> Trackable --> object)。但是另一方面,感觉新的版本里各个类的关系稍微有些杂乱,不知道后面会不会进一步重构。TF2.0的RNN相关各类关系大致如下图所示

在这里插入图片描述

相关基类

tf.keras.layers.Layer

与TF1.14的实现基本相同,不再赘述

recurrent.DropoutRNNCellMixin

与之类似的类在TF1.x中以tf.nn.rnn_cell.DropoutWrapper形式出现,但当时考虑到还没涉及到RNN的dropout就没有引入,没想到在这里还是要说一说。TF2的实现比TF1的实现要简单一些,这个类只是维护两个dropout mask,一个是用于对输入的mask,一个用于对传递状态的mask(严格说是四个,在另一个维度上还考虑是对静态图的mask还是对eager模式的mask)。实现保证mask只被创建一次,因此每个batch使用的mask都相同

RNNCell相关

无论是官方给出的文本分类教程,还是我自己从TF1.x改的用更底层API实现的代码,实际上都没有用到Cell相关的对象。但是为了完整起见(毕竟暴露的LSTM类背后还需要LSTMCell类对象作为自己的成员变量),这里还是稍作介绍

LSTMCell

本文以LSTM为主,因此先从LSTMCell说起。与TF1.x不同,在2.x版本里,LSTMCell允许传入一个implement参数,默认为1,标记LSTM各门和输出、状态的计算方式。当取默认的1时,计算方式更像是论文中的方式,逐个计算各个门的结果;而如果设为2,则使用TF1.x中组合成矩阵一并计算的方式。此外,由于LSTMCell还继承了前述DropoutRNNCellMixin接口,因此可以在call里对输入和上一时间步传来的状态做dropout。注意由于LSTM有四个内部变量 i \boldsymbol{i} i f \boldsymbol{f} f o \boldsymbol{o} o c ~ \tilde{\boldsymbol{c}} c~,因此需要各自生成四个不同的dropout mask

PeepholeLSTMCell

只是改写了LSTMCell内部变量的计算逻辑,参见在TF1.x部分的介绍

StackedRNNCells

与TF1.x中的MultiRNNCell类似

AbstractRNNCell

纯抽象类,类似TF1的RNNCell,如果用户自己实现一个RNNCell需要 可以继承于它。不过有趣的是内置的三种RNN实现所使用的Cell:SimpleRNNCellGRUCellLSTMCell均直接继承自Layer

RNN相关

tf.keras.layers.RNN

所有后续RNN相关类的基类,承担TF1.x中static_rnndynamic_rnn的双重功能,主要逻辑分别集中在初始化函数__init__buildcall中(__call__也有一些逻辑,但是只针对某些特殊情况)

RNN在初始化时传入的参数个人感觉相对来讲不如1.x直观。其允许传入的参数包括

  • cell:一种RNNCell的对象,也可以是列表或元组。当传入的参数为列表或元组时,会打包组合为StackedRNNCells类对象
  • return_sequences:默认RNN只返回最后一个时间步的输出。当此参数设为True时,返回每个时间步的输出
  • return_state:当此参数设为
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值