Model Overview
XLNet puts forward an interesting view, dividing existing pretrained models into two families: AR (Auto-Regressive) models and AE (Auto-Encoding) models. XLNet combines the strengths of both approaches, and it does so by training with a PLM (Permutation Language Model) objective.
In addition, because of the permutation objective, plain Transformer self-attention cannot tell which token is being predicted; XLNet solves this with two-stream self-attention, which makes the model aware of the target position.
Model Optimizations
Permutation Language Model
An autoregressive (AR) language model predicts the next word from the words before it, i.e., the familiar left-to-right language model; the direction can also be reversed, predicting an earlier word from the words after it.

AR language models have both strengths and weaknesses. The weakness is that they can use only the preceding context or only the following context, never both at once. The strength shows up in downstream NLP tasks: generation tasks such as text summarization and machine translation actually produce content left to right, and an AR language model naturally matches that process.
GPT is the classic AR language model. ELMo looks as if it uses both the left and right context, but it is still essentially an AR LM; this comes down to how the model is implemented. ELMo trains two separate AR language models, one left-to-right and one right-to-left, and concatenates the hidden states of the two LSTM directions to present itself as bidirectional. It is therefore a concatenation of two AR language models and remains autoregressive in nature.
An autoencoding (AE) language model predicts a word from its surrounding context. Its strengths and weaknesses are exactly the opposite of an AR LM's: it naturally accommodates a bidirectional language model and sees both the left and right context of the predicted word, but its main drawback is the [Mask] token introduced on the input side, which causes a mismatch between pretraining and fine-tuning, since [Mask] never appears during fine-tuning.

BERT randomly masks some of the words in the input X and then predicts the masked words from their context.
A permutation language model randomly permutes the words of a sentence and then predicts the last few words of the permutation autoregressively. When predicting a word it can therefore use context from both directions, and it also learns dependencies between words.
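To make this concrete, take the four-token sentence x = (x1, x2, x3, x4) and the sampled factorization order 3 → 2 → 4 → 1. The AR objective then factorizes as

$$p(\mathbf{x}) = p(x_3)\, p(x_2 \mid x_3)\, p(x_4 \mid x_3, x_2)\, p(x_1 \mid x_3, x_2, x_4)$$

so when x2 is predicted, the model conditions on x3, a word to its right, which a strictly left-to-right AR model could never see.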

XLNet implements the PLM through attention masks, without actually reordering the sentence tokens. For example, if the original sentence is [1,2,3,4] and the randomly sampled order is [3,2,4,1], the input to XLNet is still [1,2,3,4], but the mask must be changed as in the figure below.

In the mask matrix in the figure, red means unmasked and white means masked. Row 1 is the mask for token 1, which comes last in the order 3→2→4→1 and can see 3, 2, and 4 before it, so the circles in columns 2, 3, and 4 of row 1 are red (unmasked). Row 2 is the mask for token 2, which comes second in the order and can see only 3, so only the circle in column 3 of row 2 is red. Row 3 is the mask for token 3, which comes first in the order; nothing precedes it, so it sees no tokens and all four circles in row 3 are white (masked). The short sketch after this paragraph derives the same mask from the permutation.
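Below is a minimal sketch in plain NumPy (not the library implementation; plm_mask is a hypothetical helper) that derives the figure's mask, using 1 for "may attend" to match the red circles. Note that the perm_mask input of the Hugging Face implementation shown later uses the opposite convention (1 = may not attend).

import numpy as np

def plm_mask(perm):
    # position of each token within the sampled permutation
    pos = {tok: idx for idx, tok in enumerate(perm)}
    n = len(perm)
    # token i may attend to token j iff j comes before i in permutation order
    return np.array([[pos[j + 1] < pos[i + 1] for j in range(n)] for i in range(n)])

print(plm_mask([3, 2, 4, 1]).astype(int))
# [[0 1 1 1]   token 1 is last in the permutation and sees 2, 3, 4
#  [0 0 1 0]   token 2 sees only 3
#  [0 0 0 0]   token 3 comes first and sees nothing
#  [0 1 1 0]]  token 4 sees 2 and 3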
Two-Stream Self-Attention
XLNet shuffles the order of the sentence, so a token's position information becomes crucial at prediction time; at the same time, the token's content must be hidden during prediction (otherwise the input would contain the very content to be predicted, and the model would learn nothing). In other words, XLNet must see a token's position but not its content, and it uses two streams to achieve this:
Query Stream: for each token, its query stream carries only that token's position information; note this is the position in the original sentence, not in the permuted order.
Content Stream: for each token, its content stream carries that token's content information.
Query Stream computation:
Denote the query stream by g and the content stream by h. When the query stream is used to predict a target position, the Q (query) vector is computed from g and carries that position's information, while K (key) and V (value) are computed from h and carry the other tokens' content. The figure below shows how the next layer's g is computed from the current layer's g; the permutation in the figure is [3,2,4,1] and the token being computed is token 1.

Notice that when computing the Q vector for token 1, only token 1's query stream g is used, so the model receives only token 1's position information. The K and V vectors are computed from tokens 3, 2, and 4, so the model receives the content of tokens 3, 2, and 4; this is because token 1 is the last element of the permutation [3,2,4,1]. The mask matrix for this step is the same as in the previous section, with a white diagonal, i.e., the content h of the position being predicted is itself masked out.
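In the paper's notation, with z the sampled permutation, t the position within it, and m the layer index, the two streams are updated as

$$g_{z_t}^{(m)} \leftarrow \mathrm{Attention}\left(Q = g_{z_t}^{(m-1)},\ KV = h_{\mathbf{z}_{<t}}^{(m-1)}\right)$$

$$h_{z_t}^{(m)} \leftarrow \mathrm{Attention}\left(Q = h_{z_t}^{(m-1)},\ KV = h_{\mathbf{z}_{\le t}}^{(m-1)}\right)$$

The query stream may use the position z_t but not the content x_{z_t}; the content stream, described next, may use both.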

Content Stream computation:
The content stream carries the tokens' content. Because XLNet has many layers, each token's content must be passed on to the next layer; in this stream Q, K, and V are all computed from h. The content-stream computation is shown in the figure below.

Notice that computing the next layer's h1 also uses token 1's current content, which is how a token's content propagates to the next layer; note, however, that XLNet uses only g (the query stream) when making predictions. The mask matrix for the content-stream computation is shown in the figure below.

The difference from the query-stream mask lies on the diagonal: the content stream does not mask the diagonal, so each token's own information can be passed on to the next layer. The sketch below spells out the relation between the two masks.
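Continuing the NumPy sketch from above (same "1 = may attend" convention), the two stream masks differ only on the diagonal; in the full listing later, non_tgt_mask plays the content-stream role and uses the opposite 1 = masked convention:

import numpy as np

can_see = plm_mask([3, 2, 4, 1])                        # from the earlier sketch
query_stream_mask = can_see                             # diagonal closed: g must not see its own content
content_stream_mask = can_see | np.eye(4, dtype=bool)   # diagonal open: h passes itself to the next layer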
Putting the query stream and the content stream together gives the picture below.

The bottom layer in the figure is the input layer, where e(x) is initialized with the word embedding and w is initialized as a trainable vector. The code follows; feel free to leave a comment if anything is unclear.
# Imports needed by this excerpt (the full listing in the next section repeats them)
import tensorflow as tf
from tensorflow.keras import layers
from transformers.modeling_tf_utils import get_initializer
from transformers.tf_utils import shape_list, stable_softmax

class TFXLNetRelativeAttention(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        if config.d_model % config.n_head != 0:
            raise ValueError(
                f"The hidden size ({config.d_model}) is not a multiple of the number of attention "
                f"heads ({config.n_head})"
            )
        self.n_head = config.n_head
        self.d_head = config.d_head
        self.d_model = config.d_model
        self.scale = 1 / (config.d_head ** 0.5)
        self.initializer_range = config.initializer_range
        self.output_attentions = config.output_attentions
        self.layer_norm = layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        self.dropout = layers.Dropout(config.dropout)

    def rel_attn_core(
        self,
        q_head,
        k_head_h,
        v_head_h,
        k_head_r,
        seg_mat,
        attn_mask,
        head_mask,
        output_attentions,
        training=False,
    ):
        # content-based attention score
        ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
        # position-based attention score, re-indexed by relative distance
        bd = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift(bd, klen=shape_list(ac)[1])
        # segment-based attention score
        if seg_mat is None:
            ef = 0
        else:
            ef = tf.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
            ef = tf.einsum("ijbs,ibns->ijbn", seg_mat, ef)
        attn_score = (ac + bd + ef) * self.scale
        if attn_mask is not None:
            # push masked positions towards -inf before the softmax
            if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16:
                attn_score = attn_score - 65500 * attn_mask
            else:
                attn_score = attn_score - 1e30 * attn_mask
        attn_prob = stable_softmax(attn_score, axis=1)
        attn_prob = self.dropout(attn_prob, training=training)
        if head_mask is not None:
            attn_prob = attn_prob * head_mask
        attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
        if output_attentions:
            return attn_vec, attn_prob
        return attn_vec

    def rel_shift(self, x, klen=-1):
        # Transformer-XL relative shift (see the worked example after this listing)
        x_size = shape_list(x)
        x = tf.reshape(x, (x_size[1], x_size[0], x_size[2], x_size[3]))
        x = x[1:, ...]
        x = tf.reshape(x, (x_size[0], x_size[1] - 1, x_size[2], x_size[3]))
        x = x[:, 0:klen, :, :]
        return x

    def build(self, input_shape):
        initializer = get_initializer(self.initializer_range)
        self.q = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="q"
        )
        self.k = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="k"
        )
        self.v = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="v"
        )
        self.o = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="o"
        )
        self.r = self.add_weight(
            shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="r"
        )
        self.r_r_bias = self.add_weight(
            shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
        )
        self.r_s_bias = self.add_weight(
            shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_s_bias"
        )
        self.r_w_bias = self.add_weight(
            shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
        )
        self.seg_embed = self.add_weight(
            shape=(2, self.n_head, self.d_head), initializer=initializer, trainable=True, name="seg_embed"
        )
        super().build(input_shape)

    def post_attention(self, h, attn_vec, residual=True, training=False):
        # shape: (..., n_head, d_head) -> (..., d_model)
        attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)
        attn_out = self.dropout(attn_out, training=training)
        # residual connection
        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)
        return output

    def call(
        self,
        h,
        g,
        attn_mask_h,
        attn_mask_g,
        r,
        seg_mat,
        mems,
        target_mapping,
        head_mask,
        output_attentions,
        training=False,
    ):
        if g is not None:
            # two-stream attention (pretraining with a query stream)
            if mems is not None and len(shape_list(mems)) > 1:
                # shape: (mlen+qlen, bsz, d_model)
                cat = tf.concat([mems, h], axis=0)
            else:
                # shape: (qlen, bsz, d_model)
                cat = h
            # h stream
            # [qlen, bsz, n_head, d_head]
            q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
            k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
            v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)
            k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
            attn_vec_h = self.rel_attn_core(
                q_head_h,
                k_head_h,
                v_head_h,
                k_head_r,
                seg_mat,
                attn_mask_h,
                head_mask,
                output_attentions,
                training=training,
            )
            if output_attentions:
                attn_vec_h, attn_prob_h = attn_vec_h
            output_h = self.post_attention(h, attn_vec_h, training=training)
            # g stream
            q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)
            # target_mapping shape: (num_predict, qlen, bsz);
            # it selects the positions being predicted out of the query stream
            if target_mapping is not None:
                q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
                attn_vec_g = self.rel_attn_core(
                    q_head_g,
                    k_head_h,
                    v_head_h,
                    k_head_r,
                    seg_mat,
                    attn_mask_g,
                    head_mask,
                    output_attentions,
                    training=training,
                )
                if output_attentions:
                    attn_vec_g, attn_prob_g = attn_vec_g
                attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
            else:
                attn_vec_g = self.rel_attn_core(
                    q_head_g,
                    k_head_h,
                    v_head_h,
                    k_head_r,
                    seg_mat,
                    attn_mask_g,
                    head_mask,
                    output_attentions,
                    training=training,
                )
                if output_attentions:
                    attn_vec_g, attn_prob_g = attn_vec_g
            output_g = self.post_attention(g, attn_vec_g, training=training)
            if output_attentions:
                attn_prob = attn_prob_h, attn_prob_g
        else:
            # single-stream attention (fine-tuning / inference: no query stream)
            if mems is not None and len(shape_list(mems)) > 1:
                cat = tf.concat([mems, h], axis=0)
            else:
                cat = h
            q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
            k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
            v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)
            k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
            attn_vec = self.rel_attn_core(
                q_head_h,
                k_head_h,
                v_head_h,
                k_head_r,
                seg_mat,
                attn_mask_h,
                head_mask,
                output_attentions,
                training=training,
            )
            if output_attentions:
                attn_vec, attn_prob = attn_vec
            output_h = self.post_attention(h, attn_vec, training=training)
            output_g = None
        outputs = (output_h, output_g)
        if output_attentions:
            outputs = outputs + (attn_prob,)
        return outputs
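A note on rel_shift above: the bd term scores queries against sinusoidal embeddings indexed by absolute offset, and the reshape-and-slice trick re-aligns each row so that columns index relative distance instead. Here is a toy 2-D trace of the same operations, with batch and head dimensions dropped for readability (an assumption of this sketch):

import tensorflow as tf

x = tf.reshape(tf.range(8.0), (2, 4))  # qlen=2 rows, qlen+klen=4 position columns
q, k = 2, 4
y = tf.reshape(x, (k, q))              # view the buffer position-major
y = y[1:]                              # dropping one row shifts everything left
y = tf.reshape(y, (q, k - 1))
y = y[:, :2]                           # keep klen columns
print(y.numpy())                       # [[2. 3.], [5. 6.]]
# row i of the output is row i of the input shifted left by (qlen - i):
# row 0 by 2, row 1 by 1, which lines the scores up by relative distance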
XLNet reorders the sentence and then predicts in the permuted order with an AR objective, but because the permutation is random, optimization is hard and convergence is slow. XLNet therefore trains with Partial Prediction: for each permuted sentence, only the last 1/K of its tokens are predicted.
For example, with K=4 only the last quarter of the tokens are predicted. Given the sentence [1,2,3,4,5,6,7,8] and the random permutation [2,8,3,4,5,1,7,6], only 7 and 6 are predicted. The paper trains XLNet-Large with K=6, i.e., roughly the last 14.3% of the tokens are predicted, as in the tiny sketch below.
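A minimal sketch of the target selection (partial_prediction_targets is a hypothetical helper, not from the paper's code):

import math

def partial_prediction_targets(perm, K):
    # only the last ceil(len/K) tokens in permutation order are prediction targets
    num_predict = math.ceil(len(perm) / K)
    return perm[-num_predict:]

print(partial_prediction_targets([2, 8, 3, 4, 5, 1, 7, 6], K=4))  # [7, 6]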
XLNet also adopts the two most important techniques of Transformer-XL: relative positional encoding and the segment-level recurrence mechanism. See the Transformer-XL section for details.
Model References
Paper: https://arxiv.org/abs/1906.08237
Code: https://github.com/zihangdai/xlnet
The model code follows (based on the Hugging Face implementation):
from dataclasses import dataclass
from typing import List, Optional, Tuple

import tensorflow as tf
from tensorflow.keras import layers
from transformers.activations_tf import get_tf_activation
from transformers.modeling_tf_utils import get_initializer
from transformers.tf_utils import shape_list, stable_softmax
from transformers.utils import ModelOutput
class TFXLNetModel(tf.keras.Model):
    def __init__(self, config, *inputs, **kwargs):
        # plain tf.keras.Model here, so config is not forwarded to super()
        super().__init__(*inputs, **kwargs)
        self.xlnet = TFXLNetMainLayer(config, name="xlnet")

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        outputs = self.xlnet(
            input_ids=input_ids,
            attention_mask=attention_mask,
            mems=mems,
            perm_mask=perm_mask,
            target_mapping=target_mapping,
            token_type_ids=token_type_ids,
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_mems=use_mems,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        return outputs
class TFXLNetMainLayer(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.return_dict = config.return_dict
        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
        self.d_model = config.d_model
        self.same_length = config.same_length
        self.attn_type = config.attn_type
        self.bi_data = config.bi_data
        self.clamp_len = config.clamp_len
        self.n_layer = config.n_layer
        self.use_bfloat16 = config.use_bfloat16
        self.initializer_range = config.initializer_range
        self.word_embedding = TFSharedEmbeddings(
            config.vocab_size,
            config.d_model,
            initializer_range=config.initializer_range,
            name="word_embedding",
        )
        self.layers = [TFXLNetLayer(config, name=f"layer_._{i}") for i in range(config.n_layer)]
        self.dropout = layers.Dropout(config.dropout)
        self.use_mems_eval = config.use_mems_eval
        self.use_mems_train = config.use_mems_train

    def build(self, input_shape):
        initializer = get_initializer(self.initializer_range)
        # w in the figure above: the trainable embedding that initializes the query stream
        self.mask_emb = self.add_weight(
            shape=(1, 1, self.d_model),
            initializer=initializer,
            trainable=True,
            name="mask_emb",
        )
        super().build(input_shape)

    def create_mask(self, qlen, mlen):
        """
        Causal mask for attn_type == "uni" (1 = masked):
             same_length=False:      same_length=True:
             <mlen > <  qlen >       <mlen > <  qlen >
          ^  [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
             [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
        qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
             [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
          v  [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]
        """
        attn_mask = tf.ones([qlen, qlen])
        mask_u = tf.linalg.band_part(attn_mask, 0, -1)
        mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
        attn_mask_pad = tf.zeros([qlen, mlen])
        ret = tf.concat([attn_mask_pad, mask_u - mask_dia], axis=1)
        if self.same_length:
            mask_l = tf.linalg.band_part(attn_mask, -1, 0)
            ret = tf.concat([ret[:, :qlen] + mask_l - mask_dia, ret[:, qlen:]], axis=1)
        return ret
    def relative_positional_encoding(self, qlen, klen, bsz=None):
        # sinusoidal frequencies, as in Transformer-XL
        freq_seq = tf.range(0, self.d_model, 2.0)
        inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
        if self.attn_type == "bi":
            beg, end = klen, -qlen
        elif self.attn_type == "uni":
            beg, end = klen, -1
        else:
            raise ValueError(f"Unknown `attn_type` {self.attn_type}.")
        if self.bi_data:
            # forward and backward position sequences, one per half of the batch
            fwd_pos_seq = tf.range(beg, end, -1.0)
            bwd_pos_seq = tf.range(-beg, -end, 1.0)
            if self.clamp_len > 0:
                fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
                bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
            if bsz is not None:
                if bsz % 2 != 0:
                    raise ValueError(f"With bi_data, the batch size {bsz} should be divisible by 2")
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
            else:
                fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
                bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
            pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
        else:
            fwd_pos_seq = tf.range(beg, end, -1.0)
            if self.clamp_len > 0:
                fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
            pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
        return pos_emb

    @staticmethod
    def positional_embedding(pos_seq, inv_freq, bsz=None):
        sinusoid_inp = tf.einsum("i,d->id", pos_seq, inv_freq)
        pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1)
        pos_emb = pos_emb[:, None, :]
        if bsz is not None:
            pos_emb = tf.tile(pos_emb, [1, bsz, 1])
        return pos_emb

    def cache_mem(self, curr_out, prev_mem):
        # Transformer-XL segment recurrence: cache the last mem_len hidden states
        if self.reuse_len is not None and self.reuse_len > 0:
            curr_out = curr_out[: self.reuse_len]
        if self.mem_len is None or self.mem_len == 0:
            cutoff = 0
        else:
            cutoff = -self.mem_len
        if prev_mem is None:
            new_mem = curr_out[cutoff:]
        else:
            new_mem = tf.concat([prev_mem, curr_out], 0)[cutoff:]
        return tf.stop_gradient(new_mem)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        mems=None,
        perm_mask=None,
        target_mapping=None,
        token_type_ids=None,
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
        use_mems=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        if use_mems is None:
            use_mems = self.use_mems_train if training else self.use_mems_eval
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            # internal layout is time-major: (qlen, bsz)
            input_ids = tf.transpose(input_ids, perm=(1, 0))
            qlen, bsz = shape_list(input_ids)[:2]
        elif inputs_embeds is not None:
            inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
            qlen, bsz = shape_list(inputs_embeds)[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
        # input_mask uses the opposite 0/1 convention to attention_mask:
        # in input_mask, 0 means keep and 1 means mask out
        input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
        # in attention_mask, 1 means keep and 0 means mask out
        attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
        # XLNet realizes the random permutation through perm_mask:
        # perm_mask[k, i, j] = 1 means that in batch k, token i may NOT attend to
        # token j (0 means it may). For the permutation 3,2,4,1 of the sequence 1,2,3,4:
        # [[1 0 0 0]
        #  [1 1 0 1]
        #  [1 1 1 1]
        #  [1 0 0 1]]
        # Row 3 is all ones: 3 comes first and sees nothing (its own content is also
        # masked here; the diagonal is re-opened for the content stream below).
        # In row 4, columns 2 and 3 are 0, i.e. token 4 can see tokens 2 and 3.
        perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
        # target_mapping[k, i, j] = 1 means the i-th prediction target in batch k sits
        # at position j of the sequence; used in pretraining, set to None downstream
        target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
        mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
        klen = mlen + qlen
        if self.attn_type == "uni":
            attn_mask = self.create_mask(qlen, mlen)
            attn_mask = attn_mask[:, :, None, None]
        elif self.attn_type == "bi":
            attn_mask = None
        else:
            raise ValueError(f"Unsupported attention type: {self.attn_type}")
        # derive input_mask (0 = keep, 1 = mask) from attention_mask if needed
        if input_mask is None and attention_mask is not None:
            one_cst = tf.constant(1.0)
            input_mask = 1.0 - tf.cast(attention_mask, dtype=one_cst.dtype)
        if input_mask is not None and perm_mask is not None:
            data_mask = input_mask[None, :, :] + perm_mask
        elif input_mask is not None and perm_mask is None:
            data_mask = input_mask[None, :, :]
        elif input_mask is None and perm_mask is not None:
            data_mask = perm_mask
        else:
            data_mask = None
        if data_mask is not None:
            # memory tokens are always visible, so prepend zeros for the mlen positions
            if mlen > 0:
                mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz])
                data_mask = tf.concat([mems_mask, data_mask], axis=1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]
        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)
        if attn_mask is not None:
            # non_tgt_mask equals attn_mask with the diagonal cleared (1 -> 0), so a
            # token can see itself; non_tgt_mask masks the content stream, while
            # attn_mask (diagonal masked) masks the query stream
            non_tgt_mask = -tf.eye(qlen)
            if mlen > 0:
                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1)
            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype)
        else:
            non_tgt_mask = None
        # Word embedding
        if inputs_embeds is not None:
            word_emb_k = inputs_embeds
        else:
            word_emb_k = self.word_embedding(input_ids)
        # output_h is the content stream, initialized with the input word embeddings
        # shape: (qlen, bsz, d_model)
        output_h = self.dropout(word_emb_k, training=training)
        if target_mapping is not None:
            word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
            # output_g is the query stream, initialized from the trainable vector w
            # (mask_emb) so that it carries no content information
            output_g = self.dropout(word_emb_q, training=training)
        else:
            output_g = None
        # Segment embedding
        if token_type_ids is not None:
            if mlen > 0:
                mem_pad = tf.zeros([mlen, bsz], dtype=token_type_ids.dtype)
                cat_ids = tf.concat([mem_pad, token_type_ids], 0)
            else:
                cat_ids = token_type_ids
            # seg_mat[i, j] = 1 if positions i and j belong to different segments
            # shape: (qlen, klen, bsz)
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
                dtype=token_type_ids.dtype,
            )
            # one-hot encode; shape: (qlen, klen, bsz, 2)
            seg_mat = tf.one_hot(seg_mat, 2)
        else:
            seg_mat = None
        # Position embedding
        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
        pos_emb = self.dropout(pos_emb, training=training)
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.n_layer
        new_mems = ()
        if mems is None:
            mems = [None] * len(self.layers)
        attentions = [] if output_attentions else None
        hidden_states = [] if output_hidden_states else None
        for i, layer_module in enumerate(self.layers):
            if use_mems:
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if output_hidden_states:
                hidden_states.append((output_h, output_g) if output_g is not None else output_h)
            outputs = layer_module(
                output_h,
                output_g,
                non_tgt_mask,
                attn_mask,
                pos_emb,
                seg_mat,
                mems[i],
                target_mapping,
                head_mask[i],
                output_attentions,
                training=training,
            )
            output_h, output_g = outputs[:2]
            if output_attentions:
                attentions.append(outputs[2])
        if output_hidden_states:
            hidden_states.append((output_h, output_g) if output_g is not None else output_h)
        # in pretraining the query stream is the prediction output; otherwise use
        # the content stream
        output = self.dropout(output_g if output_g is not None else output_h, training=training)
        # back to batch-major: (bsz, qlen, d_model)
        output = tf.transpose(output, perm=(1, 0, 2))
        if not use_mems:
            new_mems = None
        if output_hidden_states:
            if output_g is not None:
                hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
            else:
                hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
        if output_attentions:
            if target_mapping is not None:
                # attention weights come as (query stream, content stream) pairs
                attentions = tuple(
                    tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
                )
            else:
                attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
        if not return_dict:
            return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
        return TFXLNetModelOutput(
            last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions
        )
class TFSharedEmbeddings(layers.Layer):
    def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def call(self, inputs, mode="embedding"):
        # "embedding" looks ids up; "linear" reuses the same weight as an output projection
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError(f"mode {mode} is not valid.")

    def _embedding(self, input_ids):
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)
        return tf.reshape(logits, first_dims + [self.vocab_size])
class TFXLNetLayer(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.rel_attn = TFXLNetRelativeAttention(config, name="rel_attn")
        self.ff = TFXLNetFeedForward(config, name="ff")
        self.dropout = layers.Dropout(config.dropout)

    def call(
        self,
        output_h,
        output_g,
        non_tgt_mask,
        attn_mask,
        pos_emb,
        seg_mat,
        mems,
        target_mapping,
        head_mask,
        output_attentions,
        training=False,
    ):
        outputs = self.rel_attn(
            output_h,
            output_g,
            non_tgt_mask,
            attn_mask,
            pos_emb,
            seg_mat,
            mems,
            target_mapping,
            head_mask,
            output_attentions,
            training=training,
        )
        output_h, output_g = outputs[:2]
        if output_g is not None:
            output_g = self.ff(output_g, training=training)
        output_h = self.ff(output_h, training=training)
        outputs = (output_h, output_g) + outputs[2:]
        return outputs
# TFXLNetRelativeAttention is the two-stream attention layer shown in full in the
# "Two-Stream Self-Attention" section above.
class TFXLNetFeedForward(layers.Layer):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.layer_norm = layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
        self.layer_1 = layers.Dense(
            config.d_inner, kernel_initializer=get_initializer(config.initializer_range), name="layer_1"
        )
        self.layer_2 = layers.Dense(
            config.d_model, kernel_initializer=get_initializer(config.initializer_range), name="layer_2"
        )
        self.dropout = layers.Dropout(config.dropout)
        if isinstance(config.ff_activation, str):
            self.activation_function = get_tf_activation(config.ff_activation)
        else:
            self.activation_function = config.ff_activation

    def call(self, inp, training=False):
        # position-wise feed-forward with residual connection and layer norm
        output = inp
        output = self.layer_1(output)
        output = self.activation_function(output)
        output = self.dropout(output, training=training)
        output = self.layer_2(output)
        output = self.dropout(output, training=training)
        output = self.layer_norm(output + inp)
        return output
@dataclass
class TFXLNetModelOutput(ModelOutput):
    last_hidden_state: tf.Tensor = None
    mems: Optional[List[tf.Tensor]] = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None
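Finally, a minimal usage sketch under stated assumptions: the hyperparameters come from transformers' XLNetConfig (the values below are arbitrary small ones for illustration), and use_bfloat16 is passed explicitly because TFXLNetMainLayer above reads it from the config:

from transformers import XLNetConfig

config = XLNetConfig(
    vocab_size=1000, d_model=64, n_layer=2, n_head=4, d_inner=128,
    use_bfloat16=False,
)
model = TFXLNetModel(config)
input_ids = tf.random.uniform((2, 10), maxval=config.vocab_size, dtype=tf.int32)
outputs = model(input_ids=input_ids, return_dict=True)
print(outputs.last_hidden_state.shape)  # (2, 10, 64)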