[Swin Transformer] Swin Transformer: HierarchicalVision Transformer using Shifted Windows

学习 专栏收录该内容
29 篇文章 0 订阅
image-20210403104347919

1. Motivation

将transformer从NLP应用于CV领域存在以下2个方面的挑战,图像尺度的多样性,以及图像像素相对于words的高分辨率,这会造成内存大的花销。

Challenges in adapting Transformer from language to vision arise from differences between the two domains, such as large variations in the scale of visual entities and the high resolution of pixels in images compared to words in text.

本文中作者希望扩大transformer的适用范围,使得它可以作为一个通用的backbone,应用于CV领域。

In this paper, we seek to expand the applicability of Transformer such that it can serve as a general-purpose backbone for computer vision, as it does for NLP and as CNNs do in vision.

图1表示了VIT与本文提出的Swin Transformer的区别。

image-20210403105152384

2. Contribution

为了解决以上2种问题,作者提出了Swin Transformer,用于构建分层hierarchical的,移动窗口的视觉transformer backbone。并且在目标检测,实例分割,图像分类,语义分割上都取得SOTA的结果。

3. Method

3.1 Overall Architecture

如图3,Swin Transformer的整体结构,其中包含了VIT中的patch partition linear embedding以及patch merging,所不同的是提出了一个Swin Transformer Block,其中包含了W-MSA,以及基于shifted的SW-MSA,外加一些LN和MLP,残差的常规操作。

image-20210404111700566

3.2 Shifted Window based Self-Attention

3.2.1 Self-attention in non-overlapped windows

MSA与W-MSA之间的时间复杂度比较。

image-20210405105826208

3.2.2 Shifted window partitioning in successive blocks

基于窗口的self-atte缺乏窗口之间的联系,因此为了得到cross-window的联系,同时保留高效的非重叠的窗口计算代价,作者引入了shifted window partition 方法。

image-20210403110016250

successive transformer block中的W-MSA以及SW-MSA的公式如下:

image-20210405112056879

3.3. Efficient batch computation for shifted configuratio

由于移动窗口之后,windows的数量会从h/M x w/M变为(h/M+1) x (w/M+1),并且有些windows的大小是小于MxM的,作者为了减少计算开销,提出了一个更高效的方法,如图4所示,cyclic-shifting towards the topleft direction。这样子做每个batched window(理解为整个大window)由feature map中不邻近(adjacent)的sub-windows组成,然后使用masking mechanism来限制每一个sub-window的self-attn,这样子就不用将小的windows pad填充为MxM的windows了。

the number of batched windows remains the same as that of regular window partitioning,

image-20210405113537092

3.3.4 Relative position bias

在计算self-atte时,作者加入了 relative position bias, Q , K , V ∈ R M 2 × d , Q,K,V \in R^{M^2 \times d}, Q,K,VRM2×d, B ∈ R M 2 × M 2 B \in R^{M^2 \times M^2} BRM2×M2,来计算每一个head之间的相似性,如公式4所示:

image-20210405115045215

因此,相对坐标的范围为:[-M+1, M-1],(跟relative position embedding那篇文章是一样的思路),注意这里是类似算一个(x,y),例如,对于(0,0)这个pixel,(3,3)到原点的距离为(3,3),也就是没有将二维坐标flat操作,因此构建出realtive position bias matrix B ^ ∈ R ( 2 M − 1 ) × ( 2 M − 1 ) \hat B \in R^{(2M-1)\times (2M-1)} B^R(2M1)×(2M1)。代码示例如下:

def get_relative_distances(window_size):
    indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in 		                                        range(window_size)])) # [window_size ** 2, 2]
    distances = indices[None, :, :] - indices[:, None, :] # [window_size ** 2 , window_size ** 2, 2]
    return distances



class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
		....
        ....
        if self.relative_pos_embedding:
            self.relative_indices = get_relative_distances(window_size) + window_size - 1 # 归一化为正数 [value+ M-1]  max value add  from (M-1)  to (M-1 + M-1), i.e. 2M-1 in one axis, [ 2M-1, 2M-1] in [x,y] axis. 
            self.pos_embedding = nn.Parameter(torch.randn(2 * window_size - 1, 2 * window_size - 1)) # [2M-1, 2M-1]
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))

        self.to_out = nn.Linear(inner_dim, dim)

4. Experiments

4.1. Image Classification on ImageNet-1K

image-20210405120533206

4.2. Object Detection on COCO

image-20210405120757077
image-20210405120925711

4.3. Semantic Segmentation on ADE20K

image-20210405121113813

4.4. Ablation Study

4.4.1 Shifted Windows Relative position bias

image-20210405121150340
image-20210405121301577

### 4.4.2 Different self-attention methods

image-20210405121308619
  • 0
    点赞
  • 0
    评论
  • 1
    收藏
  • 一键三连
    一键三连
  • 扫一扫,分享海报

相关推荐
©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值