MASS的功能
MASS的全名叫Masked Sequence to Sequence Pre-training for Language Generation,这个其实已经隐含了和GPT及BERT(前文有介绍)的关系: "Masked"对应了BERT原创的"Masked LM"算法,"Language Generation"指的是GPT用到的Language Modeling。
而"Sequence to Sequence"算是点明了主题:这是一个Seq2Seq的框架。
做Seq2Seq这么一个框架可以说是针对了GPT和BERT的痛点:在序列到序列(Seq2Seq)的语言生成任务中,二者都需要分别训练编码器和解码器,这导致在Seq2Seq任务中效果最好的的编码器-注意力-解码器结构没有被联合训练。因此GPT和BERT在这类任务中只能达到次优效果。
MASS非常巧妙的解决了这个问题,它利用Seq2Seq的结构同时把GPT和BERT两种不同的pretrain方法容纳在一个计算框架之下。这样做的好处有两方面:
- GPT和BERT可以提供强大的pretrain的模型,这有利于下游的transfer learning的任务。
- Seq2Seq保证了高质量的完成语言生成任务。
粗略的用一句话表示:MASS是利用GPT+BERT预训练出的Seq2Seq模型。
GPT和BERT
最好还是先介绍一下2018的双子星GPT和BERT。放在一起比较一下吧。
比较
BERT原论文反复引用了GPT,应该说BERT从GPT中吸取了很多营养。它们最大的共同点有两点:其一是都用transformer。因为二者的成功,相信transformer在NLP领域会取代(bi)LSTM成为最受欢迎的模型结构。LSTM因其在序列结构上的优点(包含long distance dependency的提取能力,以及对有序性的自动包含等等),transformer通过(multi-head)self-attention+positional embedding全部继承了下来。而且因为其摒弃了LSTM的序列结构,并没有并行能力上的限制。
去年的其他一些同一方向上的好工作比如ELMo(前文介绍)和ULMFiT(介绍)利用了(bi)LSTM,由于GPT和BERT的背书,今后的工作中恐怕transformer会成为当仁不让的主流模型。
第二个共同点是大数据加超大模型。具体的数据不再赘述了,原论文里都有。我觉得能够使大数据加大模型训练成为可能,得益于两点,一是计算能力的不断提升,另一个在于足够复杂的模型(特别是transformer)能够消化巨大的数据。
再看一下不同的地方。最大的不同点应该是pretrain的方法。GPT用的是LM, 而BERT用的是Masked LM + Next Sentence。这个不同直接导致在做下游的任务时,GPT需要针对句子对类型的任务(句子相似性,QA等等)做多次编码,因为它的预训练方式导致它不能理解多个句子并存作为输入的情况。对照看一下GPT论文中的配图:
BERT因为它的Next Sentence的预训练方法,它能够理解作为输入的句子对。原论文关于输入的图示:
它的jointly-trainining得到的Segment Embeddings以及token [SEP]都能够帮助模型理解和区别不同的句子。所以在针对下游任务时,BERT可以自然的处理句子对类型。如下图:
简单说,对BERT来讲,一次编码就可以解决句子对问题。
痛点
有很多文本生成类的任务比如machine translation,text summarization,conversational response generation都依赖一个Sequence to Sequence的框架。对于GPT和BERT这样的结构来讲,充其量只能做一个编码器或者解码器,不能支撑一个Seq2Seq的结构。除非做两次预训练,一次做编码器,另一次做解码器。但这样做编码-注意力-解码的联合训练机制就不存在了。MASS就是针对这个问题而提出的方案。
MASS模型
模型解释
MASS希望能够兼顾两点:
- 仍然采取文本生成类任务表现最优秀的编码-注意力-解码模型。
- 为了在少样本甚至零样本的任务中取得好成绩,也为了表现出很好的迁移学习的能力,同时容纳GPT和 BERT的预训练方式。
一图顶万言:
这是MASS的模型框架。Encoder和Decoder都利用transformer作为基础模型。Encoder端有连续的几个"_"代表被masked输入token [M]。直觉上看,这有点像BERT里的Masked LM, 因为整体来讲,在Decoder那一段我们希望利用两边的信息预测被Masked的部分(即图中的利用x1,x2和x7,x8来预测x3,x4,x5,x6)。唯一的区别在于BERT仅仅是利用两端的信息预测一个Masked token,这里是在预测一连串的tokens。
注意,在这个框架中“两边的信息”来自于Encoder,而不是Decoder。因为我们看到在Decoder那一端x1,x2,x7,x8都是Masked。这样做的好处体现在编码和解码两个方面:
- MASS迫使Encode