(Plug-and-Play Modules - Attention) 19. (CVPR 2023) BRA: Bi-Level Routing Attention

Paper: BiFormer: Vision Transformer with Bi-Level Routing Attention
Code: https://github.com/rayleizhu/BiFormer


1. Bi-Level Routing Attention

Vision Transformers (ViTs) capture long-range dependencies effectively through the attention mechanism, but the computational cost is enormous, which limits their application to high-resolution images. Moreover, most existing sparse attention mechanisms adopt static patterns or share the same key-value pairs across all queries, making it hard to adapt to the needs of different queries.

To address this, the paper proposes a novel Bi-Level Routing Attention (BRA), which allocates computation more flexibly in a content-aware manner through two-level routing. It works as follows: (1) Bi-level routing: irrelevant key-value pairs are first filtered out at the coarse-grained region level, and fine-grained token-to-token attention is then applied over the union of the remaining candidate regions. (2) Region graph: a region-level affinity graph is constructed and pruned so that each node keeps only its top-k connections, which determines the top-k routed regions each region should attend to. (3) Key-value gathering: the key-value pairs inside the routed regions are gathered into dense matrices so that efficient matrix multiplications can be used.

For an input feature map X, Bi-Level Routing Attention proceeds as follows:

  1. Region partition and input projection: the input feature map is divided into S×S non-overlapping regions, and linear projections produce the query, key, and value tensors.
  2. Region-to-region routing: average pooling over each region yields region-level queries and keys. Their affinity graph is computed by matrix multiplication, and keeping the top-k connections of each node gives the routing index matrix.
  3. Token-to-token attention: the key-value pairs in the routed regions are gathered into dense matrices, and attention is applied to produce the output feature map.
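The three steps above can be sketched in a few lines of PyTorch. This is a minimal toy example, not the paper's full implementation; the tensor names and the `S=2`, `topk=1` settings are illustrative, and the linear q/k/v projections are omitted for brevity:

```python
import torch

# Toy setting: batch n=1, an 8x8 feature map with C=16 channels, S=2 -> 4 regions
n, H, W, C, S, topk = 1, 8, 8, 16, 2, 1
x = torch.randn(n, H, W, C)

# 1. Partition into S*S non-overlapping regions of size (H//S, W//S)
regions = x.view(n, S, H // S, S, W // S, C).permute(0, 1, 3, 2, 4, 5)
regions = regions.reshape(n, S * S, (H // S) * (W // S), C)  # (n, S^2, hw, C)

# 2. Region-level queries/keys via average pooling over each region's tokens
q_r = regions.mean(dim=2)  # (n, S^2, C)
k_r = regions.mean(dim=2)  # (n, S^2, C)

# Region-to-region affinity graph, pruned to the top-k connections per region
affinity = q_r @ k_r.transpose(-2, -1)              # (n, S^2, S^2)
topk_logit, topk_idx = affinity.topk(topk, dim=-1)  # (n, S^2, k)

# 3. Gather the tokens of the routed regions into one dense tensor,
# over which ordinary token-to-token attention can then be applied
hw = (H // S) * (W // S)
idx = topk_idx.view(n, S * S, topk, 1, 1).expand(-1, -1, -1, hw, C)
routed = torch.gather(
    regions.unsqueeze(1).expand(-1, S * S, -1, -1, -1),  # (n, S^2, S^2, hw, C)
    dim=2, index=idx)                                    # (n, S^2, k, hw, C)
print(routed.shape)  # torch.Size([1, 4, 1, 16, 16])
```

Note how the gather produces a dense `(n, S^2, k, hw, C)` tensor: sparsity lives only in the index selection, so all subsequent operations are dense matrix multiplications.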

The advantages of Bi-Level Routing Attention are:

  1. Dynamic, query-aware sparsity: adapts to the needs of different queries and improves computational efficiency.
  2. Efficient computation: exploits sparsity to reduce computation while involving only GPU-friendly dense matrix multiplications.
  3. Strong performance: achieves excellent results on image classification, object detection, instance segmentation, and semantic segmentation.

Bi-Level Routing Attention structure diagram: (see the figure in the original paper)

2. BiFormer

Building on Bi-Level Routing Attention, the paper proposes a new Transformer architecture, BiFormer, whose core is BRA, allocating computation more flexibly in a dynamic, query-aware way. BiFormer adopts a four-stage pyramid structure similar to Swin Transformer, using overlapped patch embedding and patch merging blocks to progressively reduce the spatial resolution and increase the number of channels. Its main component is the BiFormer block, which consists of a depthwise separable convolution, a BRA module, and a 2-layer MLP, responsible respectively for encoding position information, modeling cross-position relations, and per-position embedding.
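The three-part block described above can be sketched roughly as follows. This is a minimal sketch, not the paper's implementation: `nn.MultiheadAttention` stands in for the BRA module purely for illustration, and the `mlp_ratio=3` and pre-norm layout are assumptions:

```python
import torch
import torch.nn as nn

class BiFormerBlockSketch(nn.Module):
    """Simplified BiFormer block: DWConv (position) -> attention -> 2-layer MLP.
    nn.MultiheadAttention is an illustrative stand-in for BRA."""
    def __init__(self, dim, num_heads=8, mlp_ratio=3):
        super().__init__()
        # 3x3 depthwise conv encodes position information
        self.pos = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        # 2-layer MLP for per-position embedding
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim), nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim))

    def forward(self, x):  # x: (N, C, H, W)
        x = x + self.pos(x)                    # positional encoding
        N, C, H, W = x.shape
        t = x.flatten(2).transpose(1, 2)       # (N, H*W, C) token sequence
        h = self.norm1(t)
        t = t + self.attn(h, h, h)[0]          # cross-position relation modeling
        t = t + self.mlp(self.norm2(t))        # per-position embedding
        return t.transpose(1, 2).view(N, C, H, W)

x = torch.randn(2, 64, 8, 8)
print(BiFormerBlockSketch(64)(x).shape)  # torch.Size([2, 64, 8, 8])
```

In the full model, stacks of such blocks at four resolutions form the pyramid, with the stand-in attention replaced by the `BiLevelRoutingAttention` implemented below.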


BiFormer structure diagram: (see the figure in the original paper)

3. Code Implementation

from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor


class TopkRouting(nn.Module):
    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            q, k: (n, p^2, c) tensor
        Return:
            r_weight, topk_index: (n, p^2, topk) tensor
        """
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)  # per-window pooling -> (n, p^2, c)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)  # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)  # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit)  # (n, p^2, k)

        return r_weight, topk_index


class KVGather(nn.Module):
    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        """
        r_idx: (n, p^2, topk) tensor
        r_weight: (n, p^2, topk) tensor
        kv: (n, p^2, w^2, c_kq+c_v)

        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        # select kv according to routing index
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # print(r_idx.size(), r_weight.size())
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
        topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),
                               # (n, p^2, p^2, w^2, c_kv) without mem cpy
                               dim=2,
                               index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv)
                               # (n, p^2, k, w^2, c_kv)
                               )

        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv  # (n, p^2, k, w^2, c_kv)
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')

        return topk_kv


class QKVLinear(nn.Module):
    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1)
        return q, kv


class BiLevelRoutingAttention(nn.Module):
    """
    这个模块要求的输入格式为 (N,H,W,C)
    """
    def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False,
                 side_dwconv=3,
                 auto_pad=False):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5

        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2,
                              groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)

        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing)  # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing:  # soft routing, always differentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing:  # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not supported!')

        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kernel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity':  # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # TODO: fracpool
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_downsample_mode {self.kv_downsample_mode} is not supported!')

        self.attn_act = nn.Softmax(dim=-1)
        self.auto_pad = auto_pad

    def forward(self, x):

        if self.auto_pad:
            N, H_in, W_in, C = x.size()

            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0,  # dim=-1
                          pad_l, pad_r,  # dim=-2
                          pad_t, pad_b))  # dim=-3
            _, H, W, _ = x.size()  # padded size
        else:
            N, H, W, C = x.size()
            assert H % self.n_win == 0 and W % self.n_win == 0  #

        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        q, kv = self.qkv(x)

        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean(
            [2, 3])  # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)

        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win,
                                   i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)

        r_weight, r_idx = self.router(q_win, k_win)

        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  # (n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)

        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)',
                              m=self.num_heads)  # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c',
                              m=self.num_heads)  # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c',
                          m=self.num_heads)  # to BMLC tensor (n*p^2, m, w^2, c_qk//m)

        attn_weight = (q_pix * self.scale) @ k_pix_sel  # (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = self.attn_act(attn_weight)
        out = attn_weight @ v_pix_sel
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H // self.n_win, w=W // self.n_win)

        out = out + lepe
        out = self.wo(out)

        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        return out.permute(0, 3, 1, 2)


if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(4, 512, 7, 7).to(device)
    x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
    model = BiLevelRoutingAttention(512).to(device)
    out = model(x)
    print(out.shape)  # torch.Size([4, 512, 7, 7])

This article only consolidates the plug-and-play modules from the papers, and some details are inevitably omitted. For a more thorough understanding of these modules, it is still worth reading the original paper, which will certainly offer more insight.
