一、Transformer模型中的SelfAttention机制
注意力机制的解释性博客比较多质量良莠不齐,推荐大家观看李宏毅老师关于注意力机制的讲解视频以及本人觉得对注意力机制讲解比较透彻的一篇博客[博客链接]。为更好解读注意力机制中attention-mask 的作用,现将注意力机制的原理进行总结。假设两个输入经过Wq、Wk、Wv矩阵(可训练)线性变换后获得q1=(1,2),q2=(0,1),k1=(1,0),k2=(0,1),v1=(1,0),v2=(0,1)向量。注意力机制核心就是向量q与向量k点乘后获得相似性分数(一个标量)。
同理 q2也与所有的k向量点乘获得,所有分数构成注意力分数矩阵。
attentionscore函数经过softmax函数每行都变成概率分布,每行的概率和为1。
将概率分布按行与每个v向量做叉乘运算,然后按行求和获得最终向量。
写成矩阵格式:
两个向量q与k相关度越高,点成后获得 attentionscore数值越大,softmax函数是一个单调函数,变成概率分布后,该attentionscore所占概率比重越大,与v向量相乘对最终结果b的贡献度和影响力也随之变大。
二、Attentionmask的作用
注意力机制中attentionmask的作用就是确定有效token的位置。有效token为1,填充token为0。例如一句话有3个单词,经过tokenizer分词变成整数编码后,句子起始位置增加填充字符[CLS]和句子末尾增加填充字符[SEP]后有效token数量为5个,假设规定输入张量固定长度为8,后面是填充字符[PAD],该输入的mask为[1,1,1,1,1,0,0,0]。
张量进入自注意力机制后,需要计算每个token与相邻的另外7个token的注意力分数。注意力分数矩阵变成概率分布时softmax计算量巨大,而且后三个填充字符[PAD]也会获得一定的概率分数,与v向量相乘会对最终结果有一定影响。这种情况不是我们所期望的,所以要在进行softmax函数运算之前,attentionscore要与attentionmask矩阵运算将末尾填充字符的概率分布影响力置为0,本以为是简单的相乘运算结果是相加运算,下面会介绍原因如何将填充字符的概率分布设置为0。
attention矩阵是一个8*8的矩阵,mask矩阵也需要由1*8矩阵变成8*8矩阵。 mask矩阵维度为[batch,seq_len],在本示例中batch=1,seq_len=8,为社么要在中间增加两个维度,这与multi-head self attention计算有关,增加两个维度后,与score进行计算时运用broadcast机制自动对齐2,3维度。
extended_attention_mask = attention_mask[:, None, None, :]
attentionmask 矩阵与attentionscore矩阵逐个元素相乘,每行中末尾3个score变成0影响力和贡献度消失,当初没看源码时我以为是这样操作的,但是我们忽略了一个问题就是softmax的计算公式分子是指数运算,exp(0)=1而不是0,attentionscore变成经过softmax运算变成概率分布时,每行后三个score是有数值的而不是0,与v向量相乘会对最终结果有一定影响。所以mask与score不能简单相乘而是要利用指数函数的负无穷接近0这个原理做文章,把mask码变换后与score相加,填充字符的score变成接近负无穷的数。BertModel源码中的处理公式
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
extended_attention_mask与attentionscore相加后在经过softmax函数运算,每行后三个[PAD]的score概率为0,与向量v相乘时不在有贡献值。
##############截取部分源码
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_layer)
三、Casualmask在decoder中的作用
在transformer的encoder中,selfattention是双向注意力机制每条样本的所有token都可以相互作用,但是在decoder中,一般要求只能看到对应token字符之前的token,后面的token则需要被掩盖住,这就出现了因果掩码casualmask。因果掩码的本质是一个下三角矩阵,还是举例上面的一行8个token的例子。我们获取mask掩码中token的数量8个,变成1*8的矩阵a和8*1的矩阵b,做关系运算a<=b,获得8行8列的下三角矩阵。
seq_len = attn_mask.shape[-1]
print(seq_len)
a = torch.arange(seq_len)[None,:]
b = torch.arange(seq_len)[:,None]
print(a<=b) #运用broadcast机制
tensor([[ True, False, False, False, False, False, False, False],
[ True, True, False, False, False, False, False, False],
[ True, True, True, False, False, False, False, False],
[ True, True, True, True, False, False, False, False],
[ True, True, True, True, True, False, False, False],
[ True, True, True, True, True, True, False, False],
[ True, True, True, True, True, True, True, False],
[ True, True, True, True, True, True, True, True]])
获得的下三角矩阵还需要与attentionmask相乘,后三行的每行有效token是5个但是我们经过计算获得的下三角矩阵后三行是[6,7,8]超出有效字符数量,相乘后将无效字符设置为0。
huggingface的bert_modeling中的源码示例
def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
if device is not None:
warnings.warn(
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
)
else:
device = attention_mask.device
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[
torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal_mask.dtype),
causal_mask,
],
axis=-1,
)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
return extended_attention_mask