生成模型的中Attention Mask说明

生成模型中的Attention Mask说明

最近在做文本生成任务,例如诗歌生成,问题生成,摘要生成等,使用了Bart模型,CPT模型,mt5模型,t5模型等。生成模型是基于Seq-to-Seq(Encoder-Decoder)结构,输入的文本经过Encoder编码得到向量再输入到Decoder解码生成一个文本。Encoder和Decoder会使用多层transformer 以及self-attention,需要注意Encoder和Decoder中的attention mask使用,在Decoder中的self-attention当前时刻t的词只能关注到时刻0到t-1的词,无法关注到时刻t+1的词,是因为使用了在计算self-attention的时候加入了矩阵MASK M M M的右上角被mask为 − ∞ -\infty ,从而使得Decoder进行更好的生成。Decoder中的MASK M M M矩阵是由UNIfied pre-trained Language Model(UniLM)提出来的,对应于UniLM中的Seq-to-Seq LM。下面介绍Seq-to-Seq LM原理以及在bart中的使用。

Seq-to-Seq LM 介绍

Seq-to-Seq LM是UniLM提出来的,其中UniLM中有三种LM方式分别为Unidirectional LM,Bidirectional LM和Sequence-to-Sequence LM。如下图1所示。Unidirectional LM包括left-to-right and right-to-left LM,例如left-to-right LM当前词计算attention的时候只能使用当前词前面的词,后面的词无法使用到。
unilm mask

在 Seq-to-Seq LM中的self-attention计算中加入了一个MASK 矩阵 M M M,这个MASK 矩阵 M M M的右上角的元素是 − ∞ -\infty ,左下角的元素为0。
1. Multi-Layer Transformer
输入向量 { x i } i = 1 ∣ x ∣ \{x_{i}\}_{i=1}^{|x|} {xi}i=1x,第0层的transformer的输出状态 H 0 H^{0} H0记为 H 0 = [ x 1 , x 2 , … , x ∣ x ∣ ] H^{0} = [x_{1}, x_{2}, \dots, x_{|x|}] H0=[x1,x2,,xx] L L L层的Transformer的结果记为 H l = T r a n s f o r m e r l H l − 1 H^{l} = Transformer_{l} H^{l-1} Hl=TransformerlHl1 l ∈ [ 1 , L ] l\in[1, L] l[1,L] l l l层的self-attention计算过程如下:
Q = H l − 1 W l Q K = H l − 1 W l K V = H l − 1 W l V M i j = { 0 , a l l o w t o a t t e n d − ∞ , p r e v e n t f r o m a t t e n d i n g A l = s o f t m a x ( Q K T d k + M ) V Q = H^{l -1}W_{l}^{Q} \\ K = H^{l -1}W_{l}^{K} \\ V = H^{l -1}W_{l}^{V} \\ M_{ij}=\left\{ \begin{aligned} 0, & & allow to attend \\ -\infty, & & prevent from attending \end{aligned} \right. \\ A_{l} = softmax(\frac{QK^{T}}{\sqrt{d_{k}}} + M)V Q=Hl1WlQK=Hl1WlKV=Hl1WlVMij={0,,allowtoattendpreventfromattendingAl=softmax(dk QKT+M)V
其中 H l − 1 ∈ R ∣ x ∣ × d h H^{l-1}\in R^{|x|\times d_{h}} Hl1Rx×dh W l Q , W l K , W l V ∈ R d h × d k W_{l}^{Q},W_{l}^{K},W_{l}^{V}\in R^{d_{h}\times d_{k}} WlQ,WlK,WlVRdh×dk M ∈ R ∣ x ∣ × ∣ x ∣ M\in R^{|x|\times|x|} MRx×x
已UniLM的输入句子 S 1 S_{1} S1为[SOS] t1 t2 [EOS], 输出的句子 S 2 S_{2} S2为t3 t4 t5 [EOS],将 S 1 S_{1} S1 S 2 S_{2} S2拼接后得到句子[SOS] t1 t2 [EOS] t3 t4 t5 [EOS]输入模型,MASK M M M如下:
在这里插入图片描述
上图中矩阵MASK M M M右上角阴部部分的元素均为 − ∞ -\infty ,在self-attention中的softmax的时候加上MASK M M M使得self-attention关注不到 S 2 S_{2} S2的句子,使得模型有更好的生成能力。下面介绍MASK M M M 在bart中的使用。

Bart中Encoder和Decoder中的MASK M M M介绍

1. Bart中的self-attention
Bart中的self-attention计算和上面介绍的UniLM中的attention计算方式一样,在attention计算softmax的时候加入了MASK M M M矩阵,在代码中用attention mask代替,代码如下:
在这里插入图片描述

2. 在Bart中的Encoder中的MASK M M M
在Encoder中的MASK M M M 是根据attention mask 计算得到的,attention mask大小为 [ b a t c h s i z e , s e q l e n g t h ] [batch_size, seq_length] [batchsize,seqlength]有1和0构成。在encoder中的attention mask矩阵进行了扩维变为 [ b a t c h s i z e , 1 , s e q l e n g t h , s e q l e n g t h ] [batch_size, 1, seq_length, seq_length] [batchsize,1,seqlength,seqlength],把attention mask中的元素为0的变为一个很大的负数,元素原来是1的位置处的元素为0,避免attention mask中元素为0的位置处的padding token对句子间的词之间的相关联程度的影响,相关代码如下:

def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)

结果如下:
在这里插入图片描述
3. Bart中的Decoder中的MASK M M M
在Encoder中的MASK M M M 是根据attention mask 计算得到的,attention mask大小为 [ b a t c h s i z e , s e q l e n g t h ] [batch_size, seq_length] [batchsize,seqlength]有1和0构成。在encoder中的attention mask矩阵进行了扩维变为 [ b a t c h s i z e , 1 , s e q l e n g t h , s e q l e n g t h ] [batch_size, 1, seq_length, seq_length] [batchsize,1,seqlength,seqlength],在把attention mask中的元素为0设置为 ∞ \infty 或者为一个负无穷大的数,相关代码如下

def _make_causal_mask(input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), float("-inf"))
    mask_cond = torch.arange(mask.size(-1))
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
       # create causal mask
       # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
       combined_attention_mask = None
       if input_shape[-1] > 1:
           combined_attention_mask = _make_causal_mask(
               input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
           ).to(self.device)

       if attention_mask is not None:
           # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
           expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
           combined_attention_mask = (
               expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
           )

       return combined_attention_mask

decoder中 MASK M M M 结果如下:
在这里插入图片描述
在生成模型中decoder中的self-attention计算加入MASK M M M矩阵,使得模型具有更好的生成能力。如有错误,欢迎指证。

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值