Seq2Seq之双向解码机制 | 附开源实现

本文详细介绍了Seq2Seq模型中的双向解码机制,以提高文本生成的质量,特别是处理长文本时。该机制源于Synchronous Bidirectional Neural Machine Translation,通过维护两个解码器,分别从左到右和从右到左进行解码,利用Attention机制消除解码的不对称性。在训练和预测过程中,双向解码面临信息泄漏的问题,但可以通过双向束搜索等策略进行优化。作者提供了Keras实现的参考代码,并讨论了双向解码在句子首尾生成质量提高的同时,对中间部分生成质量的影响,以及其在概率模型和信息泄漏方面的挑战。
摘要由CSDN通过智能技术生成

640?


作者丨苏剑林

单位丨追一科技

研究方向丨NLP,神经网络

个人主页丨kexue.fm


在文章玩转Keras之Seq2Seq自动生成标题中我们已经基本探讨过 Seq2Seq,并且给出了参考的 Keras 实现。 


本文则将这个 Seq2Seq 再往前推一步,引入双向的解码机制,它在一定程度上能提高生成文本的质量(尤其是生成较长文本时)。本文所介绍的双向解码机制参考自 Synchronous Bidirectional Neural Machine Translation,最后笔者也是用 Keras 实现的。


640?wx_fmt=png


640?wx_fmt=png


背景介绍


研究过 Seq2Seq 的读者都知道,常见的 Seq2Seq 的解码过程是从左往右逐字(词)生成的,即根据 encoder 的结果先生成第一个字;然后根据 encoder 的结果以及已经生成的第一个字,来去生成第二个字;再根据 encoder 的结果和前两个字,来生成第三个词;依此类推。总的来说,就是在建模如下概率分解。


640?wx_fmt=png


当然,也可以从右往左生成,也就是先生成倒数第一个字,再生成倒数第二个字、倒数第三个字,等等。问题是,不管从哪个方向生成,都会有方向性倾斜的问题。比如,从左往右生成的话,前几个字的生成准确率肯定会比后几个字要高,反之亦然。在 Synchronous Bidirectional Neural Machine Translation 给出了如下的在机器翻译任务上的统计结果:


640?wx_fmt=png


L2R 和 R2L 分别是指从左往右和从右往左的解码生成。从表中我们可以看到,如果从左往右解码,那么前四个 token 的准确率有 40% 左右,但是最后 4 个 token 的准确率只有 35%;反过来也差不多。这就反映了解码的不对称性。 


为了消除这种不对称性,Synchronous Bidirectional Neural Machine Translation 提出了一个双向解码机制,它维护两个方向的解码器,然后通过 Attention 来进一步对齐生成。


双向解码


虽然本文参考自 Synchronous Bidirectional Neural Machine Translation,但我没有完全精读原文,我只是凭自己的直觉粗读了原文,大致理解了原理之后自己实现的模型,所以并不保证跟原文完全一致。此外,这篇论文并不是第一篇做双向解码生成的论文,但它是我看到的双向解码的第一篇论文,所以我就只实现了它,并没有跟其他相关论文进行对比。 


基本思路


既然叫双向“解码”,那么改动就只是在 decoder 那里,而不涉及到 encoder,所以下面的介绍中也只侧重描述 decoder 部分。还有,要注意的是双向解码只是一个策略,而下面只是一种参考实现,并不是标准的、唯一的,这就好比我们说的 Seq2Seq 也只是序列到序列生成模型的泛指,具体 encoder 和 decoder 怎么设计,有很多可调整的地方。 


首先,给出一个简单的示意动图,来演示双向解码机制的设计和交互过程:


 Seq2Seq的双向解码机制图示


如图所示,双向解码基本上可以看成是两个不同方向的解码模块共存,为了便于描述,我们将上方称为 L2R 模块,而下方称为 R2L 模块。开始情况下,大家都输入一个起始标记(上图中的 S),然后 L2R 模块负责预测第一个字,而 R2L 模块负责预测最后一个字。


接着,将第一个字(以及历史信息)传入到 L2R 模块中,来预测第二个字,为了预测第二个字,除了用到 L2R 模块本身的编码外,还用到 R2L 模块已有的编码结果;反之,将最后一个字(以及历史信息)传入到 R2L 模块,再加上 L2R 模块已有的编码信息,来预测倒数第二个字;依此类推,直到出现了结束标记(上图中的E)。 


数学描述


换句话说,每个模块预测每一个字时,除了用到模块内部的信息外,还用到另一模块已经编码好的信息序列,而这个“用”是通过 Attention 来实现的。用公式来说,假设当前情况下 L2R 模块要预测第 n 个字,以及 R2L 模块要预测倒数第 n 个字。假设经过若干层编码后,得到的 R2L 向量序列(对应图中左上方的第二行)为: 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值