深度学习算法informer(时序预测)(二)(ProbSparse讲解)

 前言:对于 transformer,informer 做出的改进第一点是使用概率注意力机制(ProbSparse attention )代替 self attention

ProbSparse attention 是 informer 的核心,本节主要讲代码原理

本代码中

  • Q:查询张量,形状为 [B, H, L_Q, D] 
  • K:键张量,形状为 [B, H, L_K, E] 
  • sample_k:采样的键的数量
  • n_top:Top-K查询的数量,计算为 factor * ln(L_Q)
  • V:值张量,形状为 [B, H, L_V, D] 

B表示批次大小,H表示头的数量,L_K表示键的长度,E 表示键的嵌入维度,L_Q 表示查询的长度,D 表示查询/值的嵌入维度,L_V表示值的长度

ProbSparse attention 分为两步

一、计算注意力分数

1. 通过在K上添加新维度拓展到 [B, H, L_Q, L_K, E] 形状

2. 根据 sample_k 对其长度维度上进行采样得到 [B, H, L_Q, sample_k, E] 形状,计算 Q_K_sample = Q * K.T

3. 根据 Q_K_sample(形状为 [B, H, L_Q, sample_k] )计算 M 以及 M_top,基于M_top得到新的 Q(形状为 [B, H, n_top, D] )

4. 使用新的 Q 计算注意力分数(形状为 [B, H, n_top, L_K] )

代码如下

class ProbAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        """
        mask_flag:决定是否使用掩码,默认为True
        factor:缩放因子,用于计算采样的数量  factor * ln(Q长度)
        scale:决定是否缩放
        attention_dropout:注意力机制中的dropout率
        output_attention:决定是否输出注意力权重
        """
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # calculate the sampled Q_K
        # K_expand.shape
        # torch.Size([32, 8, 96, 96, 64])
        # torch.Size([96, 25])
        # K_sample.shape
        # torch.Size([32, 8, 96, 25, 64])
        # Q_K_sample.shape
        # torch.Size([32, 8, 96, 25])

        # 通过在 K 上添加一个新的维度并扩展到 [B, H, L_Q, L_K, E] 的形状
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
        # 根据 index_sample 对 K_expand 进行采样,得到一个形状为 [B, H, L_Q, sample_k, E] 的张量
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]


        # 计算查询和采样的键之间的点积,得到一个形状为 [B, H, L_Q, sample_k] 的张量
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

        # 找到具有稀疏度度量的Top-K查询
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]

        # 使用减少的Q计算Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None],
                     torch.arange(H)[None, :, None],
                     M_top, :] # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k

        # M.shape
        # torch.Size([32, 8, 96])
        # M_top.shape
        # torch.Size([32, 8, 25])
        # Q_reduce.shape
        # torch.Size([32, 8, 25, 64])
        # Q_K.shape
        # torch.Size([32, 8, 25, 96])
        return Q_K, M_top

二、计算最后输出

1. 基于注意力分数(形状为 [B, H, n_top, L_K] )和初始化的V计算得到结果(形状为 [B, H, n_top, D] )

2. 基于V的计算填充最后输出(若不使用掩码,则计算 V 在最后一个维度上的均值,形状为 [B, H, D],并在新添加的维度上扩展到 [B, H, L_Q, D];若使用掩码,则计算 V 在最后一个维度上的累计和,形状为 [B, H, L_V, D])

3. 最后输出形状为 [B, H, L_Q, D]

注意:最后输出形状第三维度中,n_top < L_Q,n_top部分基于1计算可得,L_Q - n_top部分基于2计算可得

代码如下

    def _get_initial_context(self, V, L_Q):
        B, H, L_V, D = V.shape

        if not self.mask_flag:
            # 计算 V 在最后一个维度上的均值,形状为 [B, H, D]
            V_sum = V.mean(dim=-2)
            # 将 V_sum 在新添加的维度上扩展到 [B, H, L_Q, D],并返回它的副本
            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
        else: # 使用掩码
            assert(L_Q == L_V)
            # 断言查询长度 L_Q 和值的长度 L_V 相等,即只适用于自注意力(self-attention)
            contex = V.cumsum(dim=-2)

            # V.shape
            # torch.Size([32, 8, 72, 64])
            # contex.shape
            # torch.Size([32, 8, 72, 64])
        return contex
 

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        B, H, L_V, D = V.shape

        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        # 使用softmax函数计算注意力权重,形状为 [B, H, n_top, L_V]
        attn = torch.softmax(scores, dim=-1)
        # 使用注意力权重 attn 和值 V 进行加权求和
        # 并将结果赋值到 context_in 的对应位置,对应位置形状为 [B, H, n_top, D]
        context_in[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   index, :] = torch.matmul(attn, V).type_as(context_in)
        if self.output_attention:
            # 初始化一个全为1的张量,并将其除以 L_V,形成形状为 [B, H, L_V, L_V]
            attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)
            # 将计算得到的注意力权重 attn 赋值到 attns 的对应位置
            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)

    def forward(self, queries, keys, values, attn_mask):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2,1)
        keys = keys.transpose(2,1)
        values = values.transpose(2,1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 

        U_part = U_part if U_part<L_K else L_K
        u = u if u<L_Q else L_Q
        
        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) 

        # 加入缩放系数
        scale = self.scale or 1./sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale
        # 初始化上下文向量
        context = self._get_initial_context(values, L_Q)
        # 更新上下文向量
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
        
        return context.transpose(2,1).contiguous(), attn

掩码计算代码如下

class ProbMask():
    def __init__(self, B, H, L, index, scores, device="cpu"):
        """

        创建一个形状为 (L, scores.shape[-1]) 的全为 True 的布尔张量
        .to(device): 将张量移动到指定的设备(CPU 或 GPU)
        .triu(1): 获取该张量的上三角部分(不包括主对角线),下三角部分设为 False
        """
        _mask = torch.ones(L, scores.shape[-1], dtype=torch.bool).to(device).triu(1)
        """
        _mask[None, None, :]: 在前两个维度上增加两个维度,使其形状变为 (1, 1, L, scores.shape[-1])。
        .expand(B, H, L, scores.shape[-1]): 扩展张量,使其形状变为 (B, H, L, scores.shape[-1])
        """
        _mask_ex = _mask[None, None, :].expand(B, H, L, scores.shape[-1])
        """
        torch.arange(B)[:, None, None]: 创建一个形状为 (B, 1, 1) 的张量,用于索引批次维度。
        torch.arange(H)[None, :, None]: 创建一个形状为 (1, H, 1) 的张量,用于索引注意力头维度。
        index: 索引序列维度
        """
        indicator = _mask_ex[torch.arange(B)[:, None, None],
                             torch.arange(H)[None, :, None],
                             index, :].to(device)
        self._mask = indicator.view(scores.shape).to(device)
    
    @property
    def mask(self):
        return self._mask

  • 19
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值