diffusers中的AttnProcessor源码解析(key_padding_mask和attn_mask如何在MSA中作用)

1. prepare_attention_mask

这里结合Mutil Head Attention了解下不同mask的作用。key_padding_maskattn_mask两个实际上都是作用到attn_output_weights来影响最终的output,前者专注处理序列中的<PAD>,而后者专注处理序列交叉中的“不可见”逻辑

首先先建立一个概念:多头的分头,分的是QKV的 dim维度

query=[batch_size, source_length, dim], key=[batch_size, target_len, dim], value=[batch_size, target_len, dim]
# 分头(split heads)
head_dim = dim // heads # heads=8
query=[batch_size, source_length, heads, dim//heads], key=[batch_size, target_len, heads, dim//heads], value=[batch_size, target_len, heads, dim//heads]

1.1 key_padding_mask

  • key_padding_mask,长度是(B, S),B为batch_size,S为源序列长度,即query的seq_len(NLP的token个数S/CV的patch个数HW),序列中没有到达max_len的token用<PAD>填充,key_padding_mask中对应的位置为True,计算attention时会将key中mask=True的部分省略掉。
query=[batch_size, source_length, dim], key=[batch_size, target_len, dim], value=[batch_size, target_len, dim]

在这里插入图片描述
计算self-attention时,key_padding_mask只屏蔽key中的mask:即非mask的token作为query时,和sequence中所有非mask的token作为key计算self-attention;而mask的token也可以作为query,和sequence中所有非mask的token作为key计算self-attention。(mask的token不作为key token参与计算)因为就算mask的token作为key参与了计算,最后reshape会原来的形状后,也不使用padding的部分,所以这部分注意力的计算是冗余的。
在这里插入图片描述torch实现方法,key_padding_maskattn_mask上进行实现的。如下图的伪代码实现是(不考虑多头时 attn_mask.shape=[batch, seq_len, seq_len]):torch.baddbmm计算QK然后将attn_mask加到QK矩阵上,然后mask的部分就算负无穷-inf,再经过softmax就变为0.
在这里插入图片描述
a t t e n t i o n = S o f t m a x ( Q K T d k + a t t n _ m a s k ) ⋅ V attention=Softmax(\frac{QK^T}{\sqrt{d_k}}+attn\_mask)·V attention=Softmax(dk QKT+attn_mask)V在这里插入图片描述

# 模拟key_pad_mask加到attn_mask上
import torch
from einops import rearrange, repeat
batch_size, seq_len, dim = 1, 9, 8
key_pad_mask = torch.tensor([False, False, True, False, False, True, False, False, True]).unsqueeze(0)
# tensor([[False, False,  True, False, False,  True, False, False,  True]])
key_pad_mask = torch.where(key_pad_mask, float('-inf'), 0)
# tensor([[0., 0., -inf, 0., 0., -inf, 0., 0., -inf]])
key_pad_mask = repeat(key_pad_mask, 'b s -> b ss s', ss=seq_len)
'''
tensor([[[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf]]])
'''
# 假设用casual attention: 下三角attn_mask
attn_mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool).tril(diagonal=0)
attn_mask = torch.where(attn_mask, float('-inf'), 0)  # attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
'''
tensor([[[-inf, 0., 0., 0., 0., 0., 0., 0., 0.],
         [-inf, -inf, 0., 0., 0., 0., 0., 0., 0.],
         [-inf, -inf, -inf, 0., 0., 0., 0., 0., 0.],
         [-inf, -inf, -inf, -inf, 0., 0., 0., 0., 0.],
         [-inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0.],
         [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0.],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0.],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]]])
'''
attn_mask += key_pad_mask
'''
tensor([[[-inf, 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [-inf, -inf, -inf, 0., 0., -inf, 0., 0., -inf],
         [-inf, -inf, -inf, 0., 0., -inf, 0., 0., -inf],
         [-inf, -inf, -inf, -inf, 0., -inf, 0., 0., -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]]])
'''

# query=[batch, seq_len, dim], key=[batch, tgt_len, dim], value=[batch, tgt_len, dim]
attn_score = torch.softmax(torch.baddbmm(attn_mask, query, key.transpose(-2, -1)), dim=-1)
attn_output = torch.bmm(attn_score, value)

1.2 attn_mask

  • attn_mask,长度是(B, source_length, target_length),其中B表示batch_sizesource_length表示源序列长度(Q的seq_len),target_length是目标序列长度(KV的seq_len),表示对权重矩阵做mask;
query=[batch_size, source_length, dim], key=[batch_size, target_len, dim], value=[batch_size, target_len, dim]

如果考虑多头,则要在scaled_dot_product_attention之前,把attn_mask为每个head复制一份(diffusers中使用prepare_attention_mask函数实现):

  • 如果attn_mask的shape是4维度的,初始(batch, source_length, target_length),则unseqeeze出一个head维度,沿第1维度(heads维度)复制heads份,变成(batch, heads, source_length, target_length)
  • 如果attn_mask的shape是3维度的,初始(batch, source_length, target_length),直接将注意力掩码沿着第0维度(batch维度)重复head_size次,变成(batch x heads, source_length, target_length)

这样batch x heads个头[source_length, target_length]@[target_length, source_length] 的矩阵乘法后,分别相同batch的head使用相同的attn_mask然后再进行softmax

# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

2. AttnProcessor

用于执行 self-attention 或 cross-attention:

class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
  • 6
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yuezero_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值