Vision Transformer | CVPR 2022 - Vision Transformer with Deformable Attention

CVPR 2022 - Vision Transformer with Deformable Attention

  • 首先对形状为H×W×3的输入图像进行4×4不重叠的卷积嵌入,然后进行归一化层,得到H4×W4×C 的patch嵌入。
  • 为了构建一个层次特征金字塔,Backbone包括4个阶段,stride逐渐增加。
  • 在2个连续的阶段之间,有一个不重叠的2×2卷积与stride=2来向下采样特征图,使空间尺寸减半,并使特征尺寸翻倍。
  • 前两个阶段主要是学习局部特征,并且其中的key和value尺寸较大。全局计算的Deformable Attention操作并不适合。为了实现模型容量和计算负担之间的权衡,这里采用Local Attention+Shift Window Attention的形式,以便在早期阶段有更好的表示。
  • 在三四阶段引入了Local Attention+Deformable Attention Block的形式。特征图首先通过基于Window的Local Attention进行处理聚合局部信息,然后通过Deformable Attention Block对局部增强后的token之间的全局关系进行建模。

主要改进

在这里插入图片描述

  • Deformable Attention被提出来针对Attention操作引入数据依赖的稀疏注意力
    • 使用q生成一组offsets。按照这些offsets对q进行基于双线性插值的偏移后获得k和v对应的x,之后与q构建MHSA。这里为了稳定训练过程,对偏移值使用一个预定义的因子s来对其进行放缩,从而阻止太大的偏移。 即 Δ p ← s tanh ⁡ ( Δ p ) \Delta p ← s \tanh (\Delta p) Δpstanh(Δp).
    • 这一过程中,同样基于生成的offsets来获得对应的相对位置偏置。它会被加到QK/sqrt(d)上。
    • 为了实现多样的偏移效果,这里同样将特征通道划分为G组,每组使用共享的偏移量。实际中,注意力模块的头数M被设置为偏移组G的倍数。
    • 实际计算中可以通过对输入特征图下采样 r r r被来通过offset进行下采样。文中的模型配置 r r r都设为1。

Attention计算:

在这里插入图片描述
在这里插入图片描述

核心代码

class DAttentionBaseline(nn.Module):
    def __init__(
        self,
        q_size,
        kv_size,
        n_heads,
        n_head_channels,
        n_groups,
        attn_drop,
        proj_drop,
        stride,
        offset_range_factor,
        use_pe,
        dwc_pe,
        no_off,
        fixed_pe,
    ):

        super().__init__()
        self.dwc_pe = dwc_pe
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels**-0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        self.kv_h, self.kv_w = kv_size
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.fixed_pe = fixed_pe
        self.no_off = no_off
        self.offset_range_factor = offset_range_factor

        if self.q_h == 14 or self.q_w == 14 or self.q_h == 24 or self.q_w == 24:
            kk = 5
        elif self.q_h == 7 or self.q_w == 7 or self.q_h == 12 or self.q_w == 12:
            kk = 3
        elif self.q_h == 28 or self.q_w == 28 or self.q_h == 48 or self.q_w == 48:
            kk = 7
        elif self.q_h == 56 or self.q_w == 56 or self.q_h == 96 or self.q_w == 96:
            kk = 9

        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, kk // 2, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False),
        )

        self.proj_q = nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1, padding=0)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe:
            if self.dwc_pe:
                self.rpe_table = nn.Conv2d(self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)
            elif self.fixed_pe:
                self.rpe_table = nn.Parameter(torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w))
                trunc_normal_(self.rpe_table, std=0.01)
            else:
                self.rpe_table = nn.Parameter(torch.zeros(self.n_heads, self.kv_h * 2 - 1, self.kv_w * 2 - 1))
                trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key).mul_(2).sub_(1)
        ref[..., 0].div_(H_key).mul_(2).sub_(1)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2
        return ref

    def forward(self, x):
        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device

        q = self.proj_q(x)
        q_off = einops.rearrange(q, "b (g c) h w -> (b g) c h w", g=self.n_groups, c=self.n_group_channels)
        offset = self.conv_offset(q_off)  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk

        if self.offset_range_factor > 0:
            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk], device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, "b p h w -> b h w p")
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill(0.0)
        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).tanh()

        x_sampled = F.grid_sample(
            input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
            grid=pos[..., (1, 0)],  # y, x -> x, y
            mode="bilinear",
            align_corners=True,
        )  # B * g, Cg, Hg, Wg

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)

        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        attn = torch.einsum("b c m, b c n -> b m n", q, k)  # B * h, HW, Ns
        attn = attn.mul(self.scale)

        if self.use_pe:
            if self.dwc_pe:
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(
                    B * self.n_heads, self.n_head_channels, H * W
                )
            elif self.fixed_pe:
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, self.n_sample)
            else:
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                q_grid = self._get_ref_points(H, W, B, dtype, device)
                displacement = (
                    q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2)
                    - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)
                ).mul(0.5)
                attn_bias = F.grid_sample(
                    input=rpe_bias.reshape(B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
                    grid=displacement[..., (1, 0)],
                    mode="bilinear",
                    align_corners=True,
                )  # B * g, h_g, HW, Ns
                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)
        out = torch.einsum("b m n, b c n -> b c m", attn, v)

        if self.use_pe and self.dwc_pe:
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)
        y = self.proj_drop(self.proj_out(out))
        return y, pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)

实验结果

在这里插入图片描述

消融实验

利用空间信息的形式的对比以及位置信息的引入方式Deformable Attention 的使用位置Deformable Attention offset的范围因子(代码中可见)
在这里插入图片描述在这里插入图片描述在这里插入图片描述

对比实验

分类目标检测分割
在这里插入图片描述在这里插入图片描述在这里插入图片描述

在这里插入图片描述

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值