Transform 相关知识(Mask)

1. Mask

 mask 表示掩码,它对某些值进行掩盖,使其在参数更新时不产生效果。

  • padding mask:处理非定长序列,区分padding和非padding部分,如在RNN等模型和Attention机制中的应用等
  • equence mask:防止标签泄露,如:Transformer decoder中的mask矩阵,BERT中的[Mask]位,XLNet中的mask矩阵等

1.1 Padding Mask

因为每个批次输入序列长度是不一样,需要对输入序列进行对齐。给较短的序列后面填充 0,对于输入太长的序列,截取左边的内容,把多余的直接舍弃。这些填充的位置,没什么意义,所以我们的attention机制不应该把注意力放在这些位置上,所以我们需要进行一些处理。

具体的做法是,把这些位置的值加上一个非常大的负数(负无穷),这样的话,经过 softmax,这些位置的概率就会接近0!

而我们的 padding mask 实际上是一个张量,每个值都是一个Boolean,值为 false 的地方就是我们要进行处理的地方。

1.2 Sequence mask


sequence mask 是为了使得 decoder 不能看见未来的信息。也就是对于一个序列,在 time_step 为 t 的时刻,我们的解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。因此我们需要想一个办法,把 t 之后的信息给隐藏起来。

具体做法:产生一个上三角矩阵,上三角的值全为0。把这个矩阵作用在每一个序列上。

对于 decoder 的 self-attention,同时需要padding mask 和 sequence mask 作为 attn_mask,具体实现就是两个mask相加作为attn_mask。
其他情况,attn_mask 一律等于 padding mask。

代码:

import torch

def padding_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(-2)   # [B, 1, L]

def sequence_mask(seq):
    batch_size, seq_len = seq.size()
    mask = 1- torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),diagonal=1)
    mask = mask.unsqueeze(0).expand(batch_size, -1, -1)  # [B, L, L]
    return mask

def test():
    # 以最简化的形式测试Transformer的两种mask
    seq = torch.LongTensor([[1,2,0]]) # batch_size=1, seq_len=3,padding_idx=0
    embedding = torch.nn.Embedding(num_embeddings=3, embedding_dim=10, padding_idx=0)
    query, key = embedding(seq), embedding(seq)
    scores = torch.matmul(query, key.transpose(-2, -1))

    mask_p = padding_mask(seq, 0)
    mask_s = sequence_mask(seq)
    mask_decoder = mask_p & mask_s # 结合 padding mask 和 sequence mask

    scores_encoder = scores.masked_fill(mask_p==0, -1e9) # 对于scores,在mask==0的位置填充
    scores_decoder = scores.masked_fill(mask_decoder==0, -1e9)

test()

参考:Transform详解_霜叶的博客-CSDN博客_transform

           NLP 中的Mask全解 - 知乎

    

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值