原文: https://zhuanlan.zhihu.com/p/48959800
本文介绍15年发表在NIPS上的一篇文章:Pointer Networks[1],以及后续应用了Pointer Networks的三篇文章:Get To The Point: Summarization with Pointer-Generator Networks[2]、Incorporating Copying Mechanism in Sequence-to-Sequence Learning [3]和Multi-Source Pointer Network for Product Title Summarization[4]。
一、从Sequence2Sequence说起
Sequence2Sequence(简称seq2seq)模型是RNN的一个重要的应用场景,顾名思义,它实现了把一个序列转换成另外一个序列的功能,并且不要求输入序列和输出序列等长。比较典型的如机器翻译,一个英语句子“Who are you”和它对应的中文句子“你是谁”是两个不同的序列,seq2seq模型要做的就是把这样的序列对应起来。
由于类似语言这样的序列都存在时序关系,而RNN天生便适合处理具有时序关系的序列,因此seq2seq模型往往使用RNN来构建,如LSTM和GRU。具体结构见Sequence to Sequence Learning with Neural Networks[5]这篇文章提供的模型结构图:
图1:Seq2seq模型结构
在这幅图中,模型把序列“ABC”转换成了序列“WXYZ”。分析其结构,我们可以把seq2seq模型分为encoder和decoder两个部分。encoder部分接收“ABC”作为输入,然后将这个序列转换成为一个中间向量C,向量C可以认为是对输入序列的一种理解和表示形式。然后decoder部分把中间向量C作为自己的输入,通过解码操作得到输出序列“WXYZ”。
后来,Attention Mechanism[6]的加入使得seq2seq模型的性能大幅提升,从而大放异彩。那么Attention Mechanism做了些什么事呢?一言以蔽之,Attention Mechanism的作用就是将encoder的隐状态按照一定权重加和之后拼接(或者直接加和)到decoder的隐状态上,以此作为额外信息,起到所谓“软对齐”的作用,并且提高了整个模型的预测准确度。简单举个例子,在机器翻译中一直存在对齐的问题,也就是说源语言的某个单词应该和目标语言的哪个单词对应,如“Who are you”对应“你是谁”,如果我们简单地按照顺序进行匹配的话会发现单词的语义并不对应,显然“who”不能被翻译为“你”。而Attention Mechanism非常好地解决了这个问题。如前所述,Attention Mechanism会给输入序列的每一个元素分配一个权重,如在预测“你”这个字的时候输入序列中的“you”这个词的权重最大,这样模型就知道“你”是和“you”对应的,从而实现了软对齐。
二、Pointer Networks
背景讲完,我们就可以正式进入Pointer Networks这部分了。为什么在讨论Pointer Networks之前要先说seq2seq以及Attention Mechanism呢,因为Pointer Networks正是通过对Attention Mechanism的简化而得到的。
作者开篇就提到,传统的seq2seq模型是无法解决输出序列的词汇表会随着输入序列长度的改变而改变的问题的,如寻找凸包等。因为对于这类问题,输出往往是输入集合的子集。基于这种特点,作者考虑能不能找到一种结构类似编程语言中的指针,每个指针对应输入序列的一个元素,从而我们可以直接操作输入序列而不需要特意设定输出词汇表。作者给出的答案是指针网络(Pointer Networks)。我们来看作者给出的一个例子:
图2:Pointer Networks实例:寻找凸包
这个图的例子是给定p1到p4四个二维点的坐标,要求找到一个凸包。显然答案是p1->p4->p2->p1。图a是传统seq2seq模型的做法,就是把四个点的坐标作为输入序列输入进去,然后提供一个词汇表:[start, 1, 2, 3, 4, end],最后依据词汇表预测出序列[start, 1, 4, 2, 1, end],缺点作者也提到过了,对于图a的传统seq2seq模型来说,它的输出词汇表已经限定,当输入序列的长度变化的时候(如变为10个点)它根本无法预测大于4的数字。图b是作者提出的Pointer Networks,它预测的时候每一步都找当前输入序列中权重最大的那个元素,而由于输出序列完全来自输入序列,它可以适应输入序列的长度变化。
那么Pointer Networks具体是怎样实现的呢?
我们首先来看传统注意力机制的公式:
图3:传统注意力机制公式
其中是encoder的隐状态,而是decoder的隐状态,v,W1,W2都是可学习的参数,在得到之后对其执行softmax操