1. Mask
mask 表示掩码,它对某些值进行掩盖,使其在参数更新时不产生效果。
- padding mask:处理非定长序列,区分padding和非padding部分,如在RNN等模型和Attention机制中的应用等
- equence mask:防止标签泄露,如:Transformer decoder中的mask矩阵,BERT中的[Mask]位,XLNet中的mask矩阵等
1.1 Padding Mask
因为每个批次输入序列长度是不一样,需要对输入序列进行对齐。给较短的序列后面填充 0,对于输入太长的序列,截取左边的内容,把多余的直接舍弃。这些填充的位置,没什么意义,所以我们的attention机制不应该把注意力放在这些位置上,所以我们需要进行一些处理。
具体的做法是,把这些位置的值加上一个非常大的负数(负无穷),这样的话,经过 softmax,这些位置的概率就会接近0!
而我们的 padding mask 实际上是一个张量,每个值都是一个Boolean,值为 false 的地方就是我们要进行处理的地方。
1.2 Sequence mask
sequence mask 是为了使得 decoder 不能看见未来的信息。也就是对于一个序列,在 time_step 为 t 的时刻,我们的解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。因此我们需要想一个办法,把 t 之后的信息给隐藏起来。
具体做法:产生一个上三角矩阵,上三角的值全为0。把这个矩阵作用在每一个序列上。
对于 decoder 的 self-attention,同时需要padding mask 和 sequence mask 作为 attn_mask,具体实现就是两个mask相加作为attn_mask。
其他情况,attn_mask 一律等于 padding mask。
代码:
import torch
def padding_mask(seq, pad_idx):
return (seq != pad_idx).unsqueeze(-2) # [B, 1, L]
def sequence_mask(seq):
batch_size, seq_len = seq.size()
mask = 1- torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),diagonal=1)
mask = mask.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, L]
return mask
def test():
# 以最简化的形式测试Transformer的两种mask
seq = torch.LongTensor([[1,2,0]]) # batch_size=1, seq_len=3,padding_idx=0
embedding = torch.nn.Embedding(num_embeddings=3, embedding_dim=10, padding_idx=0)
query, key = embedding(seq), embedding(seq)
scores = torch.matmul(query, key.transpose(-2, -1))
mask_p = padding_mask(seq, 0)
mask_s = sequence_mask(seq)
mask_decoder = mask_p & mask_s # 结合 padding mask 和 sequence mask
scores_encoder = scores.masked_fill(mask_p==0, -1e9) # 对于scores,在mask==0的位置填充
scores_decoder = scores.masked_fill(mask_decoder==0, -1e9)
test()
参考:Transform详解_霜叶的博客-CSDN博客_transform