XLNet: Generalized Autoregressive Pretraining for Language Understanding (2019-06-19)

Model Overview

XLNet puts forward an interesting view: existing pretrained models fall into two camps, AR (autoregressive) models and AE (autoencoding) models. XLNet combines the strengths of both, and it does so with a PLM (Permutation Language Model).

Moreover, because a permutation model is used, plain Transformer self-attention would not know which token is being predicted; XLNet makes the target position visible through a two-stream self-attention mechanism.

Model Optimizations

Permutation Language Model

An autoregressive language model predicts the next word from the preceding context, i.e. the familiar left-to-right language model; symmetrically, it can also predict the preceding word from the following context.
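
Formally, for a sequence $\mathbf{x} = (x_1, \dots, x_T)$, a forward autoregressive LM maximizes the likelihood under a left-to-right factorization (Eq. 1 in the paper):

$$\max_{\theta}\ \log p_{\theta}(\mathbf{x}) = \sum_{t=1}^{T} \log p_{\theta}(x_t \mid \mathbf{x}_{<t})$$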

[Figure: autoregressive language model]

Autoregressive language models have both strengths and weaknesses. The weakness: they can only use the preceding context or the following context, never both at once. The strength is tied to downstream NLP tasks: generation tasks such as text summarization and machine translation actually produce content left to right, and an autoregressive LM naturally matches that process.

GPT is a typical autoregressive language model. ELMo appears to use both the left and the right context, but it is still autoregressive in essence; this comes down to how the model is implemented. ELMo trains two direction-specific language models (one left-to-right and one right-to-left), i.e. two separate autoregressive LMs, and concatenates the hidden states of the two LSTM directions to emulate a bidirectional language model. It is thus a concatenation of two autoregressive LMs and remains autoregressive at heart.

An autoencoding language model predicts words from their surrounding context. Its pros and cons are exactly the mirror image of the autoregressive LM's: it fits bidirectionality naturally, seeing both the left and the right context of the word being predicted, but its main drawback is the [Mask] tokens introduced on the input side, which cause a mismatch between pretraining and fine-tuning, since no [Mask] tokens appear during fine-tuning.

[Figure: autoencoding language model]

BERT randomly masks a portion of the words in the input X and then predicts the masked words from the surrounding context.
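
In the paper's notation (Eq. 2), with $\hat{\mathbf{x}}$ the corrupted input, $\bar{\mathbf{x}}$ the set of masked tokens, and $m_t = 1$ when position $t$ is masked, BERT maximizes:

$$\max_{\theta}\ \log p_{\theta}(\bar{\mathbf{x}} \mid \hat{\mathbf{x}}) \approx \sum_{t=1}^{T} m_t \log p_{\theta}(x_t \mid \hat{\mathbf{x}})$$

The $\approx$ reflects BERT's independence assumption: the masked tokens are predicted independently of one another.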

A permutation language model randomly permutes the words of the sentence and then predicts the last few words of the permutation autoregressively. When predicting a word, the model can therefore use information from both sides of that word while still learning the dependencies between words.
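
This is the permutation LM objective of the paper (Eq. 3): with $\mathcal{Z}_T$ the set of all permutations of the index sequence $[1, \dots, T]$, and $z_t$ and $\mathbf{z}_{<t}$ the $t$-th element and the first $t-1$ elements of a permutation $\mathbf{z}$,

$$\max_{\theta}\ \mathbb{E}_{\mathbf{z} \sim \mathcal{Z}_T} \left[ \sum_{t=1}^{T} \log p_{\theta}\left(x_{z_t} \mid \mathbf{x}_{\mathbf{z}_{<t}}\right) \right]$$

Since the parameters $\theta$ are shared across all factorization orders, in expectation every position learns to use context from every other position.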

[Figure: permutation language model]

XLNet implements the PLM through an attention mask, without actually reordering the sentence's tokens. For example, if the original sentence is [1,2,3,4] and the randomly sampled order is [3,2,4,1], the input to XLNet remains [1,2,3,4], but the mask must be changed as in the figure below.

[Figure: attention mask for the permutation [3,2,4,1]; red = visible, white = masked]
In the mask matrix of the figure, red means unmasked and white means masked. Row 1 is the mask for token 1, which comes last in the order 3, 2, 4, 1 and can see the preceding 3, 2, and 4, so circles 2, 3, and 4 in the first row are red. Row 2 is the mask for token 2, which is second in the order and can only see the preceding 3, so only circle 3 in the second row is red. Row 3 is the mask for token 3, which comes first in the order; with no token before it, it can see nothing, so all four circles in that row are white.
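
As a minimal sketch (a hypothetical helper, not from the XLNet codebase), the visibility matrix in the figure can be derived from a permutation as follows; a red circle corresponds to a 1. Note that the perm_mask argument in the Hugging Face listing below uses the opposite convention, perm_mask = 1 - vis, i.e. 1 means "may not attend".

import numpy as np

def query_stream_visibility(perm):
    """vis[i, j] = 1 means token i+1 may attend to token j+1, i.e. token j+1
    comes strictly earlier in the factorization order; the diagonal stays 0
    because the query stream must not see the content of the predicted token."""
    n = len(perm)
    order = {tok: pos for pos, tok in enumerate(perm)}  # position of each 1-based token id
    vis = np.zeros((n, n), dtype=np.int32)
    for i in range(1, n + 1):
        for j in range(1, n + 1):
            if order[j] < order[i]:
                vis[i - 1, j - 1] = 1
    return vis

print(query_stream_visibility([3, 2, 4, 1]))
# [[0 1 1 1]   token 1 (last in the order) sees 2, 3 and 4
#  [0 0 1 0]   token 2 sees only 3
#  [0 0 0 0]   token 3 (first in the order) sees nothing
#  [0 1 1 0]]  token 4 sees 2 and 3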

Two-Stream Self-Attention

XLNet permutes the order of the sentence, so at prediction time the position information of the target token becomes essential; at the same time, that token's content must be masked out (otherwise the input would already contain the very information to be predicted, and the model would learn nothing). In other words, XLNet needs to see the target token's position but must not see its content, so it uses two streams:

Query Stream: for each token, its Query Stream carries only the token's position information; note this is the position in the original sentence, not in the permuted order.
Content Stream: for each token, its Content Stream carries the token's content information.

Query Stream computation
The Query Stream is written g and the Content Stream h. When the Query Stream is used to predict a position, the Q (Query) vector is computed from g and therefore carries only that position's location information, while K (Key) and V (Value) are computed from h and carry the content information of the other tokens. The figure below shows how the next layer's g is computed from the current layer's g; the permutation in the figure is [3,2,4,1], and the token being computed is 1.

[Figure: Query Stream computation for token 1 under the permutation [3,2,4,1]]

When computing the Q vector of token 1, only token 1's Query Stream g is used, so the model obtains only token 1's position information. The K and V vectors are computed from tokens 3, 2, and 4, so the model obtains their content information; token 1 sees all three because it is last in the permutation [3,2,4,1]. The mask matrix for this step is the same as in the previous section: the diagonal is white, i.e. the content information h of the position being predicted is masked.

[Figure: Query Stream attention mask]

Content Stream computation
The Content Stream carries the tokens' content information. Since XLNet has many layers, the content must be propagated to the next layer; in this stream, Q, K, and V are all computed from h. The Content Stream computation is shown in the figure below.

[Figure: Content Stream computation]

Note that when computing the next layer's h1, token 1's own current content is also used, which is how the content flows to the next layer; at prediction time, however, XLNet uses only g (the Query Stream). The mask matrix used for the Content Stream is shown below.

[Figure: Content Stream attention mask]
The difference from the Query Stream's mask is on the diagonal: the Content Stream does not mask the diagonal, so the current token's own information can be passed on to the next layer.
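
In the paper's notation, the two streams at layer $m$ are updated as

$$g_{z_t}^{(m)} \leftarrow \text{Attention}\left(Q = g_{z_t}^{(m-1)},\ KV = h_{\mathbf{z}_{<t}}^{(m-1)};\ \theta\right)$$

$$h_{z_t}^{(m)} \leftarrow \text{Attention}\left(Q = h_{z_t}^{(m-1)},\ KV = h_{\mathbf{z}_{\le t}}^{(m-1)};\ \theta\right)$$

The only difference is $\mathbf{z}_{<t}$ versus $\mathbf{z}_{\le t}$: the query stream excludes the current position's own content while the content stream includes it, which is exactly the diagonal difference between the two masks.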

Combining the Query Stream and the Content Stream gives the structure in the figure below.

[Figure: two-stream self-attention combining the Query Stream and the Content Stream]

The bottom layer of the figure is the input layer, where e(x) is initialized to the word embedding and w is initialized to a trainable vector. The corresponding code is the TFXLNetRelativeAttention class in the full listing at the end of this post; feel free to leave a comment if anything is unclear.

Partial Prediction

XLNet permutes the sentence and then predicts autoregressively along the permuted order, but since the permutations are random, optimization is difficult and convergence slow. XLNet therefore trains with Partial Prediction: for a permuted sentence, only the last 1/K of the tokens are predicted.

For example, with K=4 only the last quarter of the tokens is predicted. Given the sentence [1,2,3,4,5,6,7,8] and the random order [2,8,3,4,5,1,7,6], only 7 and 6 are predicted. The paper trains XLNet-Large with K=6, i.e. roughly the last 1/6 (about 16.7%) of the tokens are predicted.
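
A minimal sketch of the target selection (a hypothetical helper; the actual preprocessing derives the cutting point from hyperparameters):

def partial_prediction_targets(perm, K):
    """Keep only the last 1/K tokens of a factorization order as prediction targets."""
    num_predict = max(1, len(perm) // K)  # roughly 1/K of the tokens
    return perm[-num_predict:]

print(partial_prediction_targets([2, 8, 3, 4, 5, 1, 7, 6], K=4))  # [7, 6]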

XLNet also incorporates the two most important techniques from Transformer-XL: relative positional encoding and the segment-level recurrence mechanism. See the Transformer-XL section for details.
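
These surface directly in rel_attn_core in the listing below, where the pre-softmax attention score is assembled from three terms (names follow the code):

$$\text{attn\_score} = (ac + bd + ef) \cdot \text{scale}$$

Here $ac = (q + r_w)^{\top} k_h$ is the content-based term, $bd = (q + r_r)^{\top} k_r$ is the relative-position term computed from the sinusoidal relative encodings and realigned with rel_shift, and $ef = (q + r_s)^{\top} s$ is the segment term; $r_w$, $r_r$, $r_s$ are the learned biases r_w_bias, r_r_bias, and r_s_bias.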

References

Paper: https://arxiv.org/abs/1906.08237

Code: https://github.com/zihangdai/xlnet

The full model code follows (based on the Hugging Face implementation):

from dataclasses import dataclass
from typing import Optional, List, Tuple

import tensorflow as tf
from tensorflow.keras import layers
from transformers.tf_utils import shape_list
from transformers.activations_tf import get_tf_activation
from transformers.modeling_tf_utils import get_initializer
from transformers.tf_utils import stable_softmax
from transformers.utils import ModelOutput


class TFXLNetModel(tf.keras.Model):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

        self.xlnet = TFXLNetMainLayer(config, name="xlnet")

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False
    ):
        outputs = self.xlnet(
            input_ids=input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training
        )
        return outputs


class TFXLNetMainLayer(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.return_dict = config.return_dict

        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
        self.d_model = config.d_model
        self.same_length = config.same_length
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
        self.n_layer = config.n_layer
        self.use_bfloat16 = config.use_bfloat16
        self.initializer_range = config.initializer_range

        self.word_embedding = TFSharedEmbeddings(
            config.vocab_size,
            config.d_model,
            initializer_range=config.initializer_range,
            name="word_embedding"
        )
        self.layers = [TFXLNetLayer(config, name=f"layer_._{i}") for i in range(config.n_layer)]
        self.dropout = layers.Dropout(config.dropout)

        self.use_mems_eval = config.use_mems_eval
        self.use_mems_train = config.use_mems_train

    def build(self, input_shape):
        initializer = get_initializer(self.initializer_range)
        self.mask_emb = self.add_weight(
            shape=(1, 1, self.d_model),
            initializer=initializer,
            trainable=True,
            name="mask_emb"
        )
        super().build(input_shape)

    def create_mask(self, qlen, mlen):
        """
              same_length=False:      same_length=True:
              <mlen > <  qlen >       <mlen > <  qlen >
           ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
             [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
        qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
             [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
           v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]
        """
        attn_mask = tf.ones([qlen, qlen])
        mask_u = tf.linalg.band_part(attn_mask, 0, -1)
        mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
        attn_mask_pad = tf.zeros([qlen, mlen])
        ret = tf.concat([attn_mask_pad, mask_u - mask_dia], axis=1)
        if self.same_length:
            mask_l = tf.linalg.band_part(attn_mask, -1, 0)
            ret = tf.concat([ret[:, : qlen] + mask_l - mask_dia, ret[:, qlen:]], axis=1)
        return ret

    def relative_positional_encoding(self, qlen, klen, bsz=None):
        freq_seq = tf.range(0, self.d_model, 2.0)
        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))

        if self.attn_type == "bi":
            beg, end = klen, -qlen
        elif self.attn_type == "uni":
            beg, end = klen, -1
        else:
            raise ValueError(f"Unknown `attn_type` {self.attn_type}.")

        if self.bi_data:
            fwd_pos_seq = tf.range(beg, end, -1.0)
            bwd_pos_seq = tf.range(-beg, -end, 1.0)

            if self.clamp_len > 0:
                fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
                bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)

            if bsz is not None:
                if bsz % 2 != 0:
                    raise ValueError(f"With bi_data, the batch size {bsz} should be divisible by 2")
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
            pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
        else:
            fwd_pos_seq = tf.range(beg, end, -1.0)
            if self.clamp_len > 0:
                fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
        return pos_emb

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = tf.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1)
        pos_emb = pos_emb[:, None, :]

        if bsz is not None:
            pos_emb = tf.tile(pos_emb, [1, bsz, 1])
        return pos_emb

    def cache_mem(self, curr_out, prev_mem):
        if self.reuse_len is not None and self.reuse_len > 0:
            curr_out = curr_out[: self.reuse_len]

        if self.mem_len is None or self.mem_len == 0:
            cutoff = 0
        else:
            cutoff = -self.mem_len
        if prev_mem is None:
            new_mem = curr_out[cutoff:]
        else:
            new_mem = tf.concat([prev_mem, curr_out], 0)[cutoff: ]
        return tf.stop_gradient(new_mem)

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        if use_mems is None:
            use_mems = self.use_mems_train if training else self.use_mems_eval

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_ids = tf.transpose(input_ids, perm=(1, 0))
            qlen, bsz = shape_list(input_ids)[: 2]
        elif inputs_embeds is not None:
            inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
            qlen, bsz = shape_list(inputs_embeds)[: 2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
        # input_mask uses the opposite 0/1 convention to attention_mask:
        # in input_mask, 0 means "keep" and 1 means "mask"
        input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
        # in attention_mask, 1 means "keep" and 0 means "mask"
        attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
        # XLNet realizes the random permutation through perm_mask:
        # perm_mask[k, i, j] = 1 means that in batch k, token i may NOT attend to token j
        # e.g. for the sequence 1, 2, 3, 4, the order 3, 2, 4, 1 corresponds to
        # [[1 0 0 0]
        # [1 1 0 1]
        # [1 1 1 1]
        # [1 0 0 1]]
        # row 3 is all ones because 3 comes first in the order and may see nothing;
        # in row 4, columns 2 and 3 are 0, i.e. 4 may attend to 2 and 3
        perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
        # target_mapping[k, i, j] = 1 means that the i-th prediction target of batch k
        # sits at position j; used in pretraining, should be None for downstream tasks
        target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None

        mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
        klen = mlen + qlen

        if self.attn_type == "uni":
            attn_mask = self.create_mask(qlen, mlen)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == "bi":
            attn_mask = None
        else:
            raise ValueError(f"Unsupported attention type: {self.attn_type}")

        if input_mask is None and attention_mask is not None:
            one_cst = tf.constant(1.0)
            input_mask = 1.0 - tf.cast(attention_mask, dtype=one_cst.dtype)
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None, :, :] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None, :, :]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None

        if data_mask is not None:
            if mlen > 0:
                mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz])
                data_mask = tf.concat([mems_mask, data_mask], axis=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)

        if attn_mask is not None:
            # non_tgt_mask flips the diagonal of attn_mask from 1 to 0,
            # so that every position may additionally see itself;
            # non_tgt_mask is the mask used by the content stream
            non_tgt_mask = -tf.eye(qlen)
            if mlen > 0:
                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1)
            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype)
        else:
            non_tgt_mask = None

        # Word embedding
        if inputs_embeds is not None:
            word_emb_k = inputs_embeds
        else:
            word_emb_k = self.word_embedding(input_ids)
        # output_h is the content stream, initialized from the input word embeddings
        # shape: (qlen, bsz, d_model)
        output_h = self.dropout(word_emb_k, training=training)
        if target_mapping is not None:
            word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
            # output_g is the query stream, initialized from the trainable vector w (mask_emb)
            output_g = self.dropout(word_emb_q, training=training)
        else:
            output_g = None

        # Segment embedding
        if token_type_ids is not None:
            if mlen > 0:
                mem_pad = tf.zeros([mlen, bsz], dtype=token_type_ids.dtype)
                cat_ids = tf.concat([mem_pad, token_type_ids], 0)
            else:
                cat_ids = token_type_ids
            # 1 means that the token at position i and the token at position j
            # do not belong to the same segment
            # shape: (qlen, klen, bsz)
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
                dtype=token_type_ids.dtype
            )
            # convert seg_mat to one-hot form
            # shape: (qlen, klen, bsz, 2)
            seg_mat = tf.one_hot(seg_mat, 2)
        else:
            seg_mat = None

        # Position embedding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout(pos_emb, training=training)

        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.n_layer

        new_mems = ()
        if mems is None:
            mems = [None] * len(self.layers)

        attentions = [] if output_attentions else None
        hidden_states = [] if output_hidden_states else None
        for i, layer_module in enumerate(self.layers):
            if use_mems:
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if output_hidden_states:
                hidden_states.append((output_h, output_g) if output_g is not None else output_h)

            outputs = layer_module(
                output_h,
                output_g,
                non_tgt_mask,
                attn_mask,
                pos_emb,
                seg_mat,
                mems[i],
                target_mapping,
                head_mask[i],
                output_attentions,
                training=training
            )
            output_h, output_g = outputs[: 2]
            if output_attentions:
                attentions.append(outputs[2])

        if output_hidden_states:
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)

        output = self.dropout(output_g if output_g is not None else output_h, training=training)

        output = tf.transpose(output, perm=(1, 0, 2))

        if not use_mems:
            new_mems = None
        if output_hidden_states:
            if output_g is not None:
                hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
            else:
                hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
        if output_attentions:
            if target_mapping is not None:
                attentions = tuple(
                    tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
                )
            else:
                attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)

        if not return_dict:
            return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)

        return TFXLNetModelOutput(
            last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions
        )


class TFSharedEmbeddings(layers.Layer):
    def __init__(self, vocab_size, hidden_size, initializer_range, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def call(self, inputs, mode="embedding"):
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError(f"mode {mode} is not valid.")

    def _embedding(self, input_ids):
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)
        return tf.reshape(logits, first_dims + [self.vocab_size])


class TFXLNetLayer(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.rel_attn = TFXLNetRelativeAttention(config, name="rel_attn")
        self.ff = TFXLNetFeedForward(config, name="ff")
        self.dropout = layers.Dropout(config.dropout)

    def call(
            self,
            output_h,
            output_g,
            non_tgt_mask,
            attn_mask,
            pos_emb,
            seg_mat,
            mems,
            target_mapping,
            head_mask,
            output_attentions,
            training=False
    ):
        outputs = self.rel_attn(
            output_h,
            output_g,
            non_tgt_mask,
            attn_mask,
            pos_emb,
            seg_mat,
            mems,
            target_mapping,
            head_mask,
            output_attentions,
            training=training
        )
        output_h, output_g = outputs[: 2]

        if output_g is not None:
            output_g = self.ff(output_g, training=training)
        output_h = self.ff(output_h, training=training)

        outputs = (output_h, output_g) + outputs[2:]
        return outputs


class TFXLNetRelativeAttention(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        if config.d_model % config.n_head != 0:
            raise ValueError(
                f"The hidden size ({config.d_model}) is not a multiple of the number of attention "
                f"heads ({config.n_head}"
            )

        self.n_head = config.n_head
        self.d_head = config.d_head
        self.d_model = config.d_model
        self.scale = 1 / (config.d_head ** 0.5)
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions

        self.layer_norm = layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        self.dropout = layers.Dropout(config.dropout)

    def rel_attn_core(
            self,
            q_head,
            k_head_h,
            v_head_h,
            k_head_r,
            seg_mat,
            attn_mask,
            head_mask,
            output_attentions,
            training=False
    ):
        ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)

        bd = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift(bd, klen=shape_list(ac)[1])

        if seg_mat is None:
            ef = 0
        else:
            ef = tf.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
            ef = tf.einsum("ijbs,ibns->ijbn", seg_mat, ef)

        attn_score = (ac + bd + ef) * self.scale
        if attn_mask is not None:
            if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16:
                attn_score = attn_score - 65500 * attn_mask
            else:
                attn_score = attn_score - 1e30 * attn_mask

        attn_prob = stable_softmax(attn_score, axis=1)
        attn_prob = self.dropout(attn_prob, training=training)

        if head_mask is not None:
            attn_prob = attn_prob * head_mask

        attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)

        if output_attentions:
            return attn_vec, attn_prob
        return attn_vec

    def rel_shift(self, x, klen=-1):
        x_size = shape_list(x)
        x = tf.reshape(x, (x_size[1], x_size[0], x_size[2], x_size[3]))
        x = x[1:, ...]
        x = tf.reshape(x, (x_size[0], x_size[1] - 1, x_size[2], x_size[3]))
        x = x[:, 0:klen, :, :]
        return x

    def build(self, input_shape):
        initializer = get_initializer(self.initializer_range)
        self.q = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="q"
        )
        self.k = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="k"
        )
        self.v = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="v"
        )
        self.o = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="o"
        )
        self.r = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="r"
        )
        self.r_r_bias = self.add_weight(
            shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
        )
        self.r_s_bias = self.add_weight(
            shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_s_bias"
        )
        self.r_w_bias = self.add_weight(
            shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
        )
        self.seg_embed = self.add_weight(
            shape=(2, self.n_head, self.d_head), initializer=initializer, trainable=True, name="seg_embed"
        )
        super().build(input_shape)

    def post_attention(self, h, attn_vec, residual=True, training=False):
        # shape: (..., n_head, d_head) - > (..., d_model)
        attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)
        attn_out = self.dropout(attn_out, training=training)
        # Residual connection
        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)

        return output

    def call(
            self,
            h,
            g,
            attn_mask_h,
            attn_mask_g,
            r,
            seg_mat,
            mems,
            target_mapping,
            head_mask,
            output_attentions,
            training=False
    ):
        if g is not None:
            if mems is not None and len(shape_list(mems)) > 1:
                # shape: (mlen+qlen, bsz, d_model)
                cat = tf.concat([mems, h], axis=0)
            else:
                # shape: (qlen, bsz, d_model)
                cat = h
            # h stream
            # [qlen, bsz, n_head, d_head]
            q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
            k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
            v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)

            k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)

            attn_vec_h = self.rel_attn_core(
                q_head_h,
                k_head_h,
                v_head_h,
                k_head_r,
                seg_mat,
                attn_mask_h,
                head_mask,
                output_attentions,
                training=training
            )
            if output_attentions:
                attn_vec_h, attn_prob_h = attn_vec_h
            output_h = self.post_attention(h, attn_vec_h, training=training)

            # g stream
            q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)
            # target_mapping shape: (num_predict, qlen, bsz)
            # num_predict is the number of prediction targets (at most qlen)
            if target_mapping is not None:
                q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
                attn_vec_g = self.rel_attn_core(
                    q_head_g,
                    k_head_h,
                    v_head_h,
                    k_head_r,
                    seg_mat,
                    attn_mask_g,
                    head_mask,
                    output_attentions,
                    training=training
                )
                if output_attentions:
                    attn_vec_g, attn_prob_g = attn_vec_g
                attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
            else:
                attn_vec_g = self.rel_attn_core(
                    q_head_g,
                    k_head_h,
                    v_head_h,
                    k_head_r,
                    seg_mat,
                    attn_mask_g,
                    head_mask,
                    output_attentions,
                    training=training
                )
                if output_attentions:
                    attn_vec_g, attn_prob_g = attn_vec_g

            output_g = self.post_attention(g, attn_vec_g, training=training)

            if output_attentions:
                attn_prob = attn_prob_h, attn_prob_g
        else:
            if mems is not None and len(shape_list(mems)) > 1:
                cat = tf.concat([mems, h], axis=0)
            else:
                cat = h
            q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
            k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
            v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)
            k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
            attn_vec = self.rel_attn_core(
                q_head_h,
                k_head_h,
                v_head_h,
                k_head_r,
                seg_mat,
                attn_mask_h,
                head_mask,
                output_attentions,
                training=training
            )
            if output_attentions:
                attn_vec, attn_prob = attn_vec

            output_h = self.post_attention(h, attn_vec, training=training)
            output_g = None

        outputs = (output_h, output_g)
        if output_attentions:
            outputs = outputs + (attn_prob, )
        return outputs


class TFXLNetFeedForward(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        self.layer_1 = tf.keras.layers.Dense(
            config.d_inner, kernel_initializer=get_initializer(config.initializer_range), name="layer_1"
        )
        self.layer_2 = tf.keras.layers.Dense(
            config.d_model, kernel_initializer=get_initializer(config.initializer_range), name="layer_2"
        )
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        if isinstance(config.ff_activation, str):
            self.activation_function = get_tf_activation(config.ff_activation)
        else:
            self.activation_function = config.ff_activation

    def call(self, inp, training=False):
        output = inp
        output = self.layer_1(output)
        output = self.activation_function(output)
        output = self.dropout(output, training=training)
        output = self.layer_2(output)
        output = self.dropout(output, training=training)
        output = self.layer_norm(output + inp)
        return output


@dataclass
class TFXLNetModelOutput(ModelOutput):
    last_hidden_state: tf.Tensor = None
    mems: Optional[List[tf.Tensor]] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
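
To close, a usage sketch of the two pretraining-specific inputs, perm_mask and target_mapping, against the official pretrained weights via the transformers library (the checkpoint name xlnet-base-cased follows the Hugging Face hub; the printed shape assumes d_model = 768 for the base model):

import numpy as np
import tensorflow as tf
from transformers import TFXLNetModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = TFXLNetModel.from_pretrained("xlnet-base-cased")

inputs = tokenizer("XLNet uses permutation language modeling.", return_tensors="tf")
seq_len = int(inputs["input_ids"].shape[1])

# perm_mask[k, i, j] = 1 means token i may NOT attend to token j:
# hide the content of the last token from every position, then predict it.
perm_mask = np.zeros((1, seq_len, seq_len), dtype=np.float32)
perm_mask[:, :, -1] = 1.0
# target_mapping[k, i, j] = 1 means the i-th prediction target sits at position j.
target_mapping = np.zeros((1, 1, seq_len), dtype=np.float32)
target_mapping[0, 0, -1] = 1.0

outputs = model(
    inputs["input_ids"],
    perm_mask=tf.constant(perm_mask),
    target_mapping=tf.constant(target_mapping),
)
# With target_mapping set, last_hidden_state is the query stream g,
# one vector per prediction target: (batch, num_predict, d_model)
print(outputs.last_hidden_state.shape)  # (1, 1, 768)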
