狗都能看懂的Vision Transformer with Deformable Attention的讲解和代码实现

1、前言

在前面一篇博客介绍了可变形卷积,相比于普通的卷积,它自带的可学习偏移量使得模型能够关注感兴趣的区域,而不是固定的窗口。这个想法是在CNN中使用的,近些年来ViT的爆火,Self-Attention机制带来了更高的准确率。Transformer模型通过大接收场在视觉任务中展现了强大的表现力。但这种大接收场也带来了高昂的计算和内存成本,同时可能受到无关区域的干扰。随后Swin-Transformer的出现有效减少了计算量。不过其提出的MSA模块,虽然可以在一定程度上控制计算复杂度,但由于接收场的增长较慢,限制了对大物体建模的潜力。

原论文地址: https://arxiv.org/abs/2201.00520
官方开源代码:https://github.com/LeapLabTHU/DAT
Pytorch实现代码:https://github.com/Runist/DAT

2、Vision Transformer with Deformable Attention

作者参考Deformable Convolution提出了一种新颖的可变形自注意力模块(DMHA),这种模块的关键(key)和值(value)对的位置选择是数据相关的。通过这种灵活的机制,模型能够专注于重要的区域,从而捕捉到更多有用的特征,同时通过下采样,减少无关的计算。

文中会对Swin-TransformerViT进行比较,没有看过的同学建议先阅读一下。

3、模型结构

模型还是经典结构,由4个Stage组成,这里和ViT以及Swin-Transformer不同,每个Stage用到的Attention模块不同,后面会讲解原因。

DAT

  • Patch Embedding用的是Conv + LayerNorm层,和ViT的token转换是一样的。这里就不多说了Vision Transformer的博客中有详细的讲解
  • Stage 1 与 Stage 2先是一个W-MSA模块,再接另一个SW-MSA模块,这两个模块是成对出现的的,和Swin-Transformer是一摸一样。本文就不再展开讲解,请阅读Swin-Transformer的博客。
  • Stage 3 和 Stage 4则是采用W-MSA + MDHA堆叠而成。这里MDHA就是DAT中提到Deformable attention module模块。
  • 对于分类网络在代码中,还有LayerNormAvgPooling和一个全连接层组成,这个在图中没有体现。这个基本已经成了Transformer的不成文规矩了。

4、Deformable attention module详解

在Deformable Convolutional Networks中,每个特征图上的元素单独学习一个偏移量,空间复杂度为HWC,我们还需要考虑变形卷积的kernel大小,如果是3x3的卷积,空间复杂度需要再乘以3*3 = 9HWC。如果将这一机制直接应用于注意力模块,复杂度将急剧上升到 N q N k C N_qN_kC NqNkC,其中 N q N_q Nq N k N_k Nk分别是query和key的数量,而这个数量同常是和特征图的尺寸一样,所以直接套用上去会造成4次方的复杂度。

Deformable DETR也应用了类似的方法,但他们通过设置较少的key( N k = 4 N_k=4 Nk=4)减少开销。但这只在检测头中表现良好,在主干网络中,因为key的数量过少而导致信息丢失问题比较严重。

DA-Offset.png

本文提出的一个更简单的解决方案,为每个query共享key和value,如上图所示。

我们逐步讲解一下:

  1. x x x为特征图,先生成Reference Points。图上只画了4个点(实际不止4个,为了简化)。
  2. x x x重映射为 q q q,通过Offset network生成偏移量,生成偏移量的数量和Reference Points一致。Offset network内可以通过控制downsample factor r r r,来控制生成的特征图大小。
  3. Reference Points加上offsets得到中间的Features,为了减少计算量,我们会进行一次下采样。这里会根据通道数分成多个groups,类似Multi-Head Self-Attention的思想,增强不同groups之间特征多样性。
  4. 生成对应的 v v v k k k,计算出attn值,得到output。

在实际代码中,步骤3与步骤4之间还有位置编码偏执,文章里只是简单提及了一下。源码中有具体的实现方法,简单来说是生成了一个位置编码表(可学习的),利用上图的Bilinear Interpolation进行下采样,得到和attn的shape一致的偏执。

看到这里,有的同学可能会问,这个offset的网络,如果在r=1的情况下,那岂不是计算量没有减少。并且本身ViT也是全局token之间互相计算attention,那这个Deformable attention module岂不是只增加了计算量?其实我刚开始看代码的时候也有这个疑问,但后面看了下生成的offset值,有些位置是重复的,有些位置是空的。也就意味着,模型会自己关注感兴趣的区域,重复的区域就代表权重大,位置为空,权重小。所以结合Deformable Convolution的思想,就有了Deformable multi-head attention(DMHA)。并且通过控制 r r r,能够有效降低计算量,在高分辨率输入的任务比较有用。

具体代码如下:

    def forward(self, x):

        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device
        
        # proj_q is weight_q, just conv
        q = self.proj_q(x)      # [B, C, H, W] => [B, C, H, W]

        # 'b (g c) h w':表示原始张量的维度。这里,(g c) 表示一个维度,它实际上是 g 和 c 这两个维度的乘积
        # (b g) c h w:将 b 和 g 维度合并成一个新的维度
        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)     # [B, C, H, W] => [B * g, C / g, H, W]
        
        # Offset network in paper
        offset = self.conv_offset(q_off)    # [B * g, C / g, H, W] => [B * g, 2, H, W]

        # H = Hk, W = Wk
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk
        
        if self.offset_range_factor > 0:
            offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk], device=device).reshape(1, 2, 1, 1)
            # tanh [-1, 1]
            # mul(offset_range) [-1 / Hk, 1 / Wk]
            # mul(offset_range) [-offset_range_factor / Hk, offset_range_factor / Wk]
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p h w -> b h w p')             # [B * g, 2, H, W] => [B * g, H, W, 2]
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)          # [B * g, Hk, Wk, 2] 

        if self.no_off:
            offset = offset.fill(0.0)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            # To stabilize the training process
            pos = (offset + reference).tanh()

        # Bilinear Interpolation in paper
        x_sampled = F.grid_sample(
            input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),  # [B, C, H, W] => [B * g, C / g, H, W]
            grid=pos[..., (1, 0)], # y, x -> x, y
            mode='bilinear', align_corners=True)
        
        x_sampled = x_sampled.reshape(B, C, 1, n_sample)    # [B * g, C / g, H, W] => [B, C, 1, H * W]

        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)                            # [B, C, 1, H * W] => [B * nh, C / nh, H * W]
        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)    # [B, C, 1, H * W] => [B * nh, C / nh, H * W]
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)    # [B, C, 1, H * W] => [B * nh, C / nh, H * W]

        # m 和 n指代两个tensor的n_sample
        attn = torch.einsum('b c m, b c n -> b m n', q, k) # [B * nh, H * W, H * W]
        attn = attn.mul(self.scale)

        if self.use_pe:

            if self.dwc_pe:
                # Depth-wise Convolutional Position Encoding
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)
            elif self.fixed_pe:
                # Fixed Position Encoding
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, self.n_sample)
            else:
                # Relative Position Bias
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                
                q_grid = self._get_ref_points(H, W, B, dtype, device)

                displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)

                attn_bias = F.grid_sample(
                    input=rpe_bias.reshape(B * self.n_groups, self.n_group_heads, 2 * H - 1, 2 * W - 1),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True
                )

                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)

        # m 和 n都是n_sample
        out = torch.einsum('b m n, b c n -> b c m', attn, v)    # [B * nh, C / nh, H * W]

        if self.use_pe and self.dwc_pe:
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)        # [B * nh, C / nh, H * W] =>  [B, C, H, W]
        
        y = self.proj_drop(self.proj_out(out))
        
        return y, pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)

5、Deformable attention module计算量

DMHA相比于Swin-Transformer和PVT只多了一个offset networks的计算。文章给出了复杂度的计算公式:
Ω ( D M H A ) = 2 H W N S C + 2 H W C 2 + 2 N s C 2 + ( k 2 + 2 ) N s C \Omega(DMHA) = 2HWN_SC + 2HWC^2 + 2N_sC^2 + (k^2+2)N_sC Ω(DMHA)=2HWNSC+2HWC2+2NsC2+(k2+2)NsC

  • H代表feature map的高度
  • W代表feature map的宽度
  • C代表feature map的通道数
  • 其中 N S = H G W G = H W / r 2 N_S = H_GW_G = HW/r^2 NS=HGWG=HW/r2
  • k为DWConv的卷积核数量

其中前三项是Self-attention的固定计算开销,最后一项是offset network的开销。

先看一下最后一项,比较简单。由于DWConv的处理的是下采样之后的Feature map,所以 N s C N_sC NsC对应的其实是 H W / r 2 ∗ C HW/r^2*C HW/r2C r = 1 r=1 r=1时两者相等,k对应的是DWConv的卷积核数量,这没什么好解释。+2对应的是1x1的卷积和LayerNorm层。

再来从通用Transformer推导前三项,首先看一下Self-Attention的公式:
A t t e n t i o n ( Q , K , V ) = S o f t M a x ( Q K T d ) V Attention(Q,K,V)=SoftMax(\frac{QK^T}{\sqrt{d}})V Attention(Q,K,V)=SoftMax(d QKT)V
对于feature map的每个像素(或称为token,patch),都要通过 W q W_q Wq W k W_k Wk W v W_v Wv生成对应qkv。这里假设q,k,v的向量长度与feature map的channel数量C保持一致。那么对应所有像素生成Q的过程如下:
X H W × C ⋅ W q C × C = Q H W × C X^{HW \times C} \cdot W_q^{C \times C} = Q^{HW \times C} XHW×CWqC×C=QHW×C

  • X H W × C X^{HW \times C} XHW×C为所有token拼接一起得到的矩阵(一共有HW个像素,每个像素的深度为C)
  • W q C × C W_q^{C \times C} WqC×C为生成query的变换矩阵
  • Q H W × C Q^{HW \times C} QHW×C为所有像素通过 W q C × C W_q^{C \times C} WqC×C得到的query拼接后的矩阵

根据矩阵运算的计算量公式可以得到生成 Q Q Q的计算量为 H W × C × C HW \times C \times C HW×C×C,生成K和V的过程一样,同理都是 H W C 2 HWC^2 HWC2,那么总共是 3 H W C 2 3HWC^2 3HWC2(但这里注意一下,生成q是不用下采样的,只有k和v才需要,所以对应的是公式中的 2 N s C 2 2N_sC^2 2NsC2)。接下来 Q Q Q K T K^T KT相乘,对应计算量为 ( H W ) 2 C (HW)^2C (HW)2C,由于有 r r r的存在,所以实际是 H W N S C HWN_SC HWNSC
Q H W × C ⋅ K T ( C × H W ) = X H W × H W Q^{HW \times C} \cdot K^{T(C \times HW)} = X^{HW \times HW} QHW×CKT(C×HW)=XHW×HW
这里忽略除以 d \sqrt{d} d 以及softmax的计算量,假设得到 A H W × H W A^{HW \times HW} AHW×HW,最后还要乘以 V V V,这里对应的计算量是 ( H W ) 2 C (HW)^2C (HW)2C,由于有 r r r的存在,所以实际是 H W N S C HWN_SC HWNSC
A h w × h w ⋅ V h w × C ) = X h w × C A^{hw \times hw} \cdot V^{hw \times C)} = X^{hw \times C} Ahw×hwVhw×C)=Xhw×C
那么对应单头的Self-Attention模块,总共需要 3 H W C 2 + ( H W ) 2 C + ( H W ) 2 C = 3 H W C 2 + 2 ( H W ) 2 C 3HWC^2 + (HW)^2C + (HW)^2C = 3HWC^2 + 2(HW)^2C 3HWC2+(HW)2C+(HW)2C=3HWC2+2(HW)2C。而在实际使用过程中,使用的是多头的Multi-head Self-Attention模块,在之前的文章中有进行过实验对比,多头注意力模块相比单头注意力模块的计算量仅多了最后一个融合矩阵 W O W_O WO的计算量 H W C 2 HWC^2 HWC2

所以总共加起来是: 4 H W C 2 + 2 ( H W ) 2 C 4HWC^2 + 2(HW)^2C 4HWC2+2(HW)2C

由于下采样因子 r r r的存在,DMHA的开销为: H W N S C + H W N S C + 2 N s C 2 HWN_SC + HWN_SC + 2N_sC^2 HWNSC+HWNSC+2NsC2 = 2 H W N S C + 2 N s C 2 2HWN_SC + 2N_sC^2 2HWNSC+2NsC2,再加上一个生成q的计算量 H W C 2 HWC^2 HWC2以及多头注意力模块的最后一个融合矩阵 W O W_O WO的计算量 H W C 2 HWC^2 HWC2
所以总共是 2 H W N S C + 2 N s C 2 + H W C 2 + H W C 2 = 2 H W N S C + 2 H W C 2 + 2 N s C 2 2HWN_SC + 2N_sC^2 + HWC^2 + HWC^2 = 2HWN_SC + 2HWC^2 + 2N_sC^2 2HWNSC+2NsC2+HWC2+HWC2=2HWNSC+2HWC2+2NsC2

现在把 r = 1 r=1 r=1代入到 N S = H G W G = H W / r 2 N_S = H_GW_G = HW/r^2 NS=HGWG=HW/r2,就和 4 H W C 2 + 2 ( H W ) 2 C 4HWC^2 + 2(HW)^2C 4HWC2+2(HW)2C一样。

6、模型详细配置参数

DAT

下图(表1)是原论文中给出的关于不同DAT的配置,T(Tiny),S(Small),B(Base),其中:

  • win. sz. 7x7表示使用的窗口(Windows)的大小
  • N表示堆叠的次数
  • head表示多头注意力模块中head的个数
  • groups表示DMHA中的分组注意力机制的数量

model

需要提一下,为什么只在最后两个Stage中用到Deformable attention module,作者在第三章节说了,是为了实现模型容量和计算负担之间的权衡。但在作者最新的代码中,模型结构已经不是这样了。且不再使用W-MSA模块,SW-MSA模块。笔者根据其以往提交的代码,重新修改后,已和原文对应。

  • 6
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值