XLNet论文解读+部分代码解读

XLNet Generalized Autoregressive Pretraining

Publisher:
作者: Zhilin Yang, Zihang Dai等
单位:Carnegie Mellon University, Google Brain
论文链接

XLNet:运行机制及和Bert的异同比较

飞跃芝麻街:XLNet 详解

XLnet:比Bert更强大的预训练模型

【BERT 系列 2】之 XLNet

中文XLNet预训练模型

transformer-XL相对位置编码示意图

github上面别人提出针对双向transformer-XL的疑问

1.Motivation

  作者认为,Bert这种基于自编码的具有双向建模能力的模型性能比基于自回归建模的语言模型的性能要好。但是Bert因为采用了Mask的训练方式, 忽略了被Mask掉词之间的依赖关系;同时因为Bert是基于自编码的,所以和基于自回归的模型相比较的,在面对生成任务的时候有缺陷;而且因为Bert是基于transformer的,所以在序列长度方面有限制。所以作者就希望可以可以融合自编码和自回归的优点,然后设计出来一个模型。

2.自回归语言模型和自编码语言模型

2.1 自回归语言模型

  其实就是指RNN这一类的模型,这一类的模型的优化目标是最大化概率 p ( x ) = ∑ t = 1 T p ( x t ∣ X < t ) p(x)=\sum_{t=1}^{T}p(x_t | X_{< t}) p(x)=t=1Tp(xtX<t) 或者最大化概率 p ( x ) = ∑ 1 t = T p ( x t ∣ X < t ) p(x)=\sum_{1}^{t=T}p(x_t | X_{< t}) p(x)=1t=Tp(xtX<t) 。其中 X = ( x 1 , . . . , x T ) X=(x_1,...,x_T) X=(x1,...,xT). 自回归语言模型有以下的优点:(其实就是rnn的优点)

  • 自回归语言模型的模型符合生成任务的需求,就是那种一个一个的生成我们需要的字符。类似人写字一样一个一个的写出来.

  • 同时自回归语言模型可以学习要预测词之间的的关系,因为被预测的词是根据上一个词预测出来的。

然后也有一些缺点:

  • 但是自回归语言模型难以并行计算,以及无法提供一些语言理解任务中需要的双向上下文信息。作者人为, 虽然有双线RNN的模型, 但是这些模型本身都是单向, 然后拼接成的双向.

2.2 自编码语言模型

  指的是以Bert为代表这类语言模型,这类模型的特点是将输入的数据破坏掉,然后通过剩下的输入数据,再次重建出来被破坏掉的数据,即优化的是 p ( x ∣ x ~ ) p(x | \tilde{x}) p(xx~) 。然后自编码语言么的优点是:(就是Bert的优点, 其实就是和上面的RNN反着来吧)

  • 可以更好的提供上下文信息和提供并行计算

缺点是:

  • 从某些角度说,是在模型的输入端加入噪声,然后模型进行去除噪声。因为Bert引入了[MASK]符号,但是这个符号因为不出现在微调阶段,就导致了预训练和微调之间的差距。而且作者人为Bert这种mask操作导致模型学习不到被mask掉词之间的关系.

  • 和预测的时候生成任务的生成不匹配,导致生成任务效果较差.

3.XLNet的主要改进

3.1 Permutation Language Modeling

  为了引入自回归模型的优点,同时可以看到上下文,论文中提出了排列语言模型。例如序列[1,2,3,4],如果是为了预测3,那么我们怎么样才能使用自回归的方式让3看到上下文呢?

  如图3-1,打乱序列的顺序之后输入到模型里面,就可以发现,当我们需要预测3的时候,我们能看到的只能是3前面的单词,如果打乱序列顺序之后,我们可以看到第一行3可以看到1,2;第二行3可以看到2,4;以此类推,那么要预测的词就可以看到上下文了。

1 , 2 , 3 , 4 2 , 4 , 3 , 1 1 , 4 , 3 , 2 图 3 − 1 \begin{aligned} &1, 2, 3, 4 \\ &2, 4, 3, 1 \\ &1, 4, 3, 2 \\ &图3-1 \end{aligned} 1,2,3,42,4,3,11,4,3,231
在这里插入图片描述
图 3-2: 论文提供的打乱顺序的输入示意图, 图中表示的都是3这个位置的单词在不同输入的顺序下面可以看到的词. 因为采用了transformer-XL, 所以前面会有一个mem的记忆模块.

3.2 双流自我注意力结构

  双流自我注意力结构应该说是对于Permutation Language Modeling的实现方式。

3.2.1 attention mask

  首先是加入的attention masks,因为XLNet为了保证预训练的输入和之后的微调的时候保证一致,不可能直接打乱序列的输入顺序。所以模型的输入还是正常的序列顺序。为了实现打乱顺序的需要,模型在进行attention的时候,进行了mask操作。

在这里插入图片描述
图 3-3

  对于输入顺序如果是3-2-4-1的情况下,目前只看content stream的mask图。对于第一行,代表的是1,因为打乱顺序之后,1相当于是最后输入,那么1可以看到所有的序列;对于第二行对应的是2,2相当于第二个输入进去的,所以2能看到的是3和2,那么对应的mask区域只有第二个和第三个可以不被mask掉。以此类推。

  但是这种处理方法,带来了一个问题,例如依旧是3-2-4-1的输入顺序,在预测单词4的时候,模型用这种mask方式可以看到4自己的信息;如果把4也mask掉,那么在预测1的时候又看不到4的信息了。同时为了解决Bert使用[MASK]代替被屏蔽单词的问题,所以作者设计了一个新的结构去解决这个问题, 也就是加入了Query stream的另外一个流的自我注意力结构.

3.2.2 其余的双流操作

  双流自我注意力结构分为2部分,分别是内容流和查询流。内容流,则是正常的transformer-XL的计算方式(和transformer-XL其实是略有不同的, 3.2.3会详细的讲和transformer-XL计算的差异),使用上面介绍的mask方法。查询流中,attention中的Q只包含了输入的位置信息, 而K,V则包含了内容信息,但是K,V包含的内容信息只包括输入序列的位置t的前面的1-t个单词的内容信息,并且不包含第t个单词,所以和content的mask相比,对角线上的都mask掉了。
在这里插入图片描述
图 3-4

  XLNet和Bert类似,采用了类似的“掩盖”一部分输入的序列,然后让模型去预测。XLNet每次掩盖的时候,选择的都是打乱顺序之后的序列的最后面的一部分,这样也和自回归模型的模式类似。掩盖的比例,根据论文的实验,选择的是1/7-1/6,也就是14.28%-16.67%。

在这里插入图片描述
图 3-5

  XLNet采用的是transformer-XL的模型结构,对于位置编码,直接采用transformer-XL的相对位置编码方式。这里进行的修改是相对段编码。

  Bert采用的是绝对位置段编码,但是因为XLNet采用的是transformer-XL,所以需要使用上次的记忆数据,这里也采用了相对句子段编码。对于i计算j的注意力值的时候,如果ij来自同一段,那么采用s+,否则采用s-,然后计算出来的aij直接加到正常的注意力值里面即可。有两个优点:

  • 增加了模型的泛化性

  • 保证了微调的时候遇到多个句子依旧可以正常使用

a i j = ( q i + b ) T s i j , 其 中 s i j = { s + , i j 来 自 同 一 段 s − , i j 来 自 不 同 段 a_{ij} = (q_i + b)^T s_{ij}, 其中s_{ij} = \left\{\begin{matrix} s_+ &, ij来自同一段 \\ s_- &, ij来自不同段 \end{matrix}\right. aij=(qi+b)Tsij,sij={s+s,ij,ij

3.2.3 XLNet的双向transformer-XL

  这里针对的是内容流, 因为双流的查询流的代码没看(因为没机器可以跑预训练, 于是放弃了). 其实这里的双向transformer-XL和单向的transformer-XL的实现的差别主要还是在计算下面的公式的b和d上面:

A i , j r e l = E x i T W q T W k , E E x j ⏟ a + E x i T W q T W k , R R i − j ⏟ b + u T W k , E E x j ⏟ c + v T W k , R R i − j ⏟ d A^{rel}_{i,j} = \underbrace{E_{x_i}^T W_q^T W_{k,E} E_{x_j}}_{a} + \underbrace{E_{x_i}^T W_q^T W_{k,R} \color{blue} R_{i-j}}_{b} + \underbrace{{\color{red} u^T} W_{k,E} E_{x_j}}_{c} + \underbrace{{\color{red}v^T} W_{k,R} \color{blue} R_{i-j}}_{d} Ai,jrel=a ExiTWqTWk,EExj+b ExiTWqTWk,RRij+c uTWk,EExj+d vTWk,RRij

可以去看一下XLNet中生产位置信息的实现代码:(下面截取的是huggingface的XLNet的pytorch版本的实现代码, 和原版的tensorflow的基本完全一样)

首先是生成 R i − j R_{i-j} Rij 的部分:

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        """create relative positional encoding."""
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

        if self.attn_type == 'bi':
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == 'uni':
            # beg, end = klen - 1, -1
            beg, end = klen, -1
        else:
            raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))

        if self.bi_data:
            fwd_pos_seq = torch.arange(beg, end, -1.0, dtype=torch.float)
            bwd_pos_seq = torch.arange(-beg, -end, 1.0, dtype=torch.float)

            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
                bwd_pos_seq = bwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)

            if bsz is not None:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz//2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz//2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)

            pos_emb = torch.cat([fwd_pos_emb, bwd_pos_emb], dim=1)
        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0)
            if self.clamp_len > 0:
                fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb

调用relative_positional_encoding来生成 R i − j R_{i-j} Rij, 上面的代码写的比较复杂, 但是实际上我们需要关注的代码只有下面这么多:

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = torch.einsum('i,d->id', pos_seq, inv_freq)
        pos_emb = torch.cat([torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)], dim=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = pos_emb.expand(-1, bsz, -1)

        return pos_emb

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        """create relative positional encoding."""
        freq_seq = torch.arange(0, self.d_model, 2.0, dtype=torch.float)
        inv_freq = 1 / torch.pow(10000, (freq_seq / self.d_model))

        if self.attn_type == 'bi':
            # beg, end = klen - 1, -qlen
            beg, end = klen, -qlen
        elif self.attn_type == 'uni':
            ...
        else:
            raise ValueError('Unknown `attn_type` {}.'.format(self.attn_type))

        if self.bi_data:
            ...
        else:
            fwd_pos_seq = torch.arange(beg, end, -1.0)
            if self.clamp_len > 0:
                ...
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)

        pos_emb = pos_emb.to(next(self.parameters()))
        return pos_emb

仔细看代码的话, 我们可以发现, 这次生成的位置的长度, 实际上是: mem_len + input_len + input_len, 和以前的transformer-XL相比较的话, 这里增加的一个 input_len长度的位置, 其实这里增加的就是双向的部分, 因为原来的单向transformer-XL只有一个mem_len + input_len的长度, 这里增加的input_len长度就是反向的位置信息. 然后剩下的关于位置信息的计算, 比较不一样的应该就是之后的截取部分了, 具体的代码位置是在类XLNetRelativeAttention中的rel_attn_core函数和``函数中:

    def rel_attn_core(self, q_head, k_head_h, v_head_h, k_head_r, seg_mat=None, attn_mask=None, head_mask=None):
        """Core relative positional attention operations."""

        ...

        # position based attention score
        bd = torch.einsum('ibnd,jbnd->ijbn', q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift(bd, klen=ac.shape[1])

        ...

        return attn_vec
    
        @staticmethod
    def rel_shift(x, klen=-1):
        """perform relative shift to form the relative attention score."""
        x_size = x.shape

        x = x.reshape(x_size[1], x_size[0], x_size[2], x_size[3])
        x = x[1:, ...]
        x = x.reshape(x_size[0], x_size[1] - 1, x_size[2], x_size[3])
        # x = x[:, 0:klen, :, :]
        x = torch.index_select(x, 1, torch.arange(klen, device=x.device, dtype=torch.long))

        return x

怎么说呢, 这里没有进行pad操作, 直接就截取了, 下面的示意图大概可以说明这个过程:

在这里插入图片描述
图3-6

图3-6基本就是XLNet中的双向transformer-XL的位置信息生成的过程, 怎么说呢? 这里假设是输入了3个单词, 然后记忆模块可以记忆2个单词, 那么XLNet中生成的input_len + mem_len + input_len位置中, 前面的浅蓝色的是从右到左的位置信息, 中间的是来自上次的记忆信息, 右边的深蓝色是从左到右的位置信息(其实严格的来说, 并没有左右的区分, 只是为了实现双向, 所以就弄了两个, 反正论文里面没写, 我就这样比喻一下). 然后经过这一系列的操作, 我们可以看到最终结果的第一行中的位置信息里面包含所有的mem部分的位置以及剩下的所有的输入字符的位置, 剩下的每一行都是2个绿色的加3个蓝色的构成的, 当然有浅蓝和深蓝, 但是确实都包含了所有的字符的位置信息, 这里我也比较迷, 最后搞成了这样的结果, 但是模型事实的运行结果证明这样是可行的.(或许是我理解错了)

4.于Bert的对比

  作者认为,Bert无法学习到被mask掉部分的词之间的信息,作者举例,对于句子“New York is a city”,预测的目标是“New York”,那么Bert和XLNet的优化目标分别是:

ξ B e r t = l o g   p ( N e w ∣ i s a c i t y ) + l o g   p ( Y o r k ∣ i s a c i t y ) ξ X L N e t = l o g   p ( N e w ∣ i s a c i t y ) + l o g   p ( Y o r k ∣ N e w , i s a c i t y ) \begin{aligned} & \xi_{Bert} = log \ p(New | is a city) + log \ p(York | is a city) \\ & \xi_{XLNet} = log \ p(New | is a city) + log \ p(York | New, is a city) \end{aligned} ξBert=log p(Newisacity)+log p(Yorkisacity)ξXLNet=log p(Newisacity)+log p(YorkNew,isacity)

  根据优化目标可以看到,XLNet在预测出来New,会在预测York的时候把New加入到先决条件中。这样,被mask掉的词也可以学习它们之间的关系。

5.实验对比

5.1 长文档阅读理解

  RACE数据集是一个针对中国中学生和高中生的英语考试的数据集,数据集包含近10万个问题,是目前最难的阅读理解数据集,且数据集中的段落的平均长度在300个单词之上,比一般的阅读理解的数据集的长度都长的多。

  XLNet对于Bert提升了大概接近10%左右,根据后面的一些实验分析,这里的提升除了加入了PLM,更多的可能是因为使用了transformer-XL。

在这里插入图片描述
图 5-1

  SQuAD数据集和RACE类似,都是长文档级别的阅读理解数据集,这里的效果提升针对Bert而言提升也比较明显。

在这里插入图片描述
图 5-2

5.2 消融实验

  参考https://zhuanlan.zhihu.com/p/70257427,论文中使用和Bert相同的训练量,训练了一个XLNet-base模型用于和Bert进行对比。

在这里插入图片描述
图5-3

  首先看DAE+transformer-XL的实验结果,这里相当于Bert中的transformer替换成了transformer-XL,是研究长文档因素造成的影响。RACE和SQuAD2.0都是长文档的阅读理解,分数提升1和3个点,但是MNLI和SST-2都是句对分类任务,提升就不明显了。说明transformer-XL带来了长文档的效果提升。

  之后再参考XLNet-Base的效果,这里体现的是PLM带来的提升,可以看到四个数据集都有1个点左右的提升,说明PLM是可以给模型带来收益的。

  除此之外,根据网上别人的分析,根据前面XLNet-large的得分情况,再对比消融研究中的XLNet-Base的得分情况,可以大概得出训练数据量的提升(接近10倍Bert训练量)给XLNet的模型在长文本阅读理解上的提升占到30%左右。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值