ViT和SwinTransformer详解

ViT是Google brain发表于ICLR'21上的工作,开创性将transformer用在vision领域,且图像识别性能超CNN,至今引用3.8w+;原文:https://arxiv.org/pdf/2010.11929

SwinTransformer是微软亚洲研究院发表于ICCV'21上,获best paper,在多个视觉任务上获sota,打破CNN垄断vision backbone的现状,至今引用1.8w+;原文:https://openaccess.thecvf.com/content/ICCV2021/papers/Liu_Swin_Transformer_Hierarchical_Vision_Transformer_Using_Shifted_Windows_ICCV_2021_paper.pdf

建议读原文,这些文章优雅、简洁、深刻。

下面按照三部分进行,分别是Attention介绍、ViT详解、SwinTransformer详解。与常规文章讲解不同,我会多采用QA进行展开。

1. Attention介绍

这涉及到NeurIPS发表的“Attention is all you need”,这篇文章引用已经12w+,理解注意力机制是学习transformer的核心。

Q: general attention和self attention区别?

A: 相同点是均需要计算qkv,不同之处,self attention的input只有x,而general attention的input除了有x(映射得到kv),还有q(查询query)。

Self attention layer介绍

步骤:1. 输入x,通过映射矩阵Wq,Wk,Wv,得到qkv(D维)

           2. q和k进行对齐操作,如:q0会分别与不同的k进行点乘操作,得e0(e矩阵第一列)

           3. 注意力机制:softmax操作。如,从e0得a0(a矩阵第一列),为0~1之间的注意力权重

           4. 输出:v和注意力权重a的加权和,如:y0为a0和所有v的注意力加权和

=》不同于CNN的局部特性,此处的自注意力很好地体现了全局特性。

仔细观察,可以发现self-attention layer具备permutation invariant的性质(置换不变)

现实中,不管是语言token还是vision patch token,位置不同,显然我们应该得到不同的内容向量y才是合理的。

因此,有必要加入位置编码,将位置信息考虑进来进行自注意力学习。

对于每个输入xj,给出位置编码pj。使用位置编码函数pos,pj=pos(j),将位置j映射到D维向量(因为x是D维)。对于pos函数的选取此处不详细展开。

Multi-head self attention layer介绍

多头自注意力层就是transformer里核心模块。

Q:为什么要multi-head?

A:本质是为提取更好的特征,类似于CNN中卷积核也是多组,以得到多个特征谱。不同的是,CNN中卷积核小,计算量小,特征谱数量都是几十、几百。这儿的Multi-head不会很多,一般不超过10。

2. ViT详解

这篇文章的writing也可以当作范本,反复学习。

Q:标题两个keys,一个是an image is worth 16x16 words, 另一个是at scale,分别突出了什么?

A:前者突出将图像按照文字的处理方式,把一张图表示成了16x16 tokens。另一个关键点at scale,则与transformer的优势关联起来,也暗含了transformer要获的良好性能的前提。

Q:Transformer的天然优势是什么?

A:主要是excellent scalability,当模型和训练集增加时,并没有saturating performance。可以处理超大规模的训练数据。另一个是self-attention带来的computational efficiency,很多计算可以高度并行。

Q:CNN的天然优势是什么?

A:主要是inductive bias,在卷积的过程中,我们使用了translation equaivariance(平移不变性)、locality(局部性)来保留2D相邻结构。这些使得CNN在少量训练数据时候也能获得很好的性能。

Q:什么时候Transformer会比CNN更好?

A:通常,小训练数据集时候,convolutional inductive bias会很有用。当,数据集规模足够大的时候,最终large scaling training会比inductive bias表现好。这是合理的,因为Transformer学习中没有inductive bias,其特征时只能从大规模数据中学习。

Conclusion中提及的几个有前瞻性的点,现在均已经实现:)

1)self-supervised vs. supervised learning,之间的gap已经去掉;

2)scaling law,随着scaling提升,模型性能提升,现在已经是大模型发展遵循的发展规律;

3)transformer在segmentation、detection上的发展,现在已经横扫这些视觉任务。

# Transformer Encoder (depth x)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head,dropout=dropout),
                FeedForward(dim,mlp_dim,dropout=dropout)]))
    
    def forward(self, x):
        for att, ff in self.layers:
            x = attn(x)+x
            x = ff(x)+x
        return self.norm(x)
# Multi-Head Attention
# 与self-attention layer中的operation保持一致

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropot=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not(heads==1 and dim_head==dim)

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)    
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim*3, bias=False)
        
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q,k,v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) # batch(b), sequence length(n), heads(h), dim(d) 
        dots = torch.matmul(q, k.transpose(-1,-2))*self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
    
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
# FeedForward (Transformer Encoder第二个部分)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout))
        
        def forward(self, x):
            return self.net(x)
# ViT framework
# 包括输入图像的处理方式以及具体的任务

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        
        num_pathes = (image_height//patch_height)*(image_width//patch_width) # sequence length
        patch_dim = channels*patch_height*patch_width
        assert pool in {'cls','mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' #输入序列的第一个位置会添加一个特殊的标记,称为 [CLS] 标记

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
        self.cls_token = nn.Parameter(torch.randn(1,1,dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes) # classification task

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b,n,_ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:,:(n+1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool=='mean' else x[:,0]

        x = self.to_latent(x)
        return self.mlp_head(x)  

        

此处,pos_embedding是随机给的,transformer的输出后pool只能选cls或者mean中之一,然后进行MLP对任务的预测。

这里没有涉及到transformer decoder设计。

3. SwinTransformer详解

Q:ViT不好吗,SwinTransformer主要解决哪些关键问题?

A:如果图像分辨率变大,按照patch的size进行切分,这时候图像块的数量会增加,相应的计算复杂度quadratic增加,除此外切分的patch也相对较大(下采样倍数高),特征提取信息不准。不能很好处理高分辨率图。除此外,ViT这种固定的图像块切分方法对于不同大小的视觉实体而言不是很合理,当物体远小于或者大于patch时候,很难有效提取特征。不同物体的尺寸和比例差异很大,不像单词的长度相对固定。不能很好处理大小变化的视觉物体

个人感觉SwinTransformer中窗口概念,类似与CNN中卷积核,窗口shift类似于CNN中stride,不同这个shift(向右、向下)更加灵活。不同在于,CNN针对局部图像感受野直接去求W,而SwinTransformer则是利用Self-attention更高级的方式去求局部图像的特征。

很多技巧都是用于减少(分窗口、窗口移动)运算量。

Q:SwinTransformer主要贡献?

A:第一,层级的特征谱方式使计算复杂度对于图像尺寸而言是linear而不是quadratic,可以处理高分辨率的图像。第二、shifted window很好解决了视觉物体大小变化的特点。

Q:SwinTransformer主要设计思想?

A:全局的注意力机制只在小范围内做,然后在不同层级上提特征(W-MSA,提出窗口的概念,窗口内进行多头注意力机制)。此外,利用shifted window将各个窗口之间的信息进行通信,完美达到捕获全局的上下文信息的优势(SW-MSA,此处就是滑动窗口的多头注意力机制)。这两部分就是Swin Transformer blocks的主要组成部分。

对比:

1)SwinTransformer有很多窗口(红色框),且在不同的层级上,窗口的划分是不同的。ViT将整图作为一个窗口,一直进行全局注意力机制计算。

2)SwinTransformer先进行4x下采样,将4*4个pixels作为一个小patch,在划定的窗口内进行注意力计算,然后是8x下采样,最后是16x下采样。ViT直接下采样16倍,后面保持相同的下采样规律。

SwinTransformer Blocks

1. W-MSA介绍(窗口间不涉及信息传递)

这个提出的目的是在窗口内进行kqv的求解,既能减少计算复杂度,也能使用更小的patch size,使下采样倍数不用很大。

Q:具体减少了多少运算量?

A:运算量主要分为三部分:1)to kqv,2)qk对齐,3)与v加权和。Att(Q,K,V)=Softmax(\frac{QK^{T}}{\sqrt{d}})V

1) X^{hw\times C}通过矩阵运算W^{C\times C}生成Q^{hw\times C },K^{hw\times C},V^{hw\times C}.总运算量为3hwC^{2}

2) qk对齐,运算量为(hw)^{2}C,得A^{hw\times hw}

3) 与v加权,得B^{hw\times C},运算量为(hw)^{2}C

4) 多头注意力机制,多了一个融合矩阵W,B^{hw\times C}\cdot W^{C\times C}=O^{hw\times C},计算量hwC^{2}

总计,4hwC^{2}+2(hw)^{2}C 公式一

假设W-MSA的窗口长和宽为M,代入上面公式为,

4M^{2}C^{2}+2M^{4}C

\frac{h}{M}\times \frac{w}{M}窗口,所以为,4hwC^{2}+2M^{2}hwC 公式二

缺点:减少了运算量,但窗口之间由于没有任何通信,导致确实全局感受野。

2.SW-MSA介绍(窗口间进行信息传递)

两层之间发生了窗口的移动(Shift),偏移的量是:往右、往下偏移M/2个像素。移动后,划分出的第二列3个窗口能够完成相邻窗口的信息交流。

缺点:原来4个窗口,移动后变成9个窗口,且大小不一。总之,移动后窗口的数量增多,从\frac{h}{M}\times \frac{w}{M}变成(\frac{h}{M}+1)\times (\frac{w}{M}+1),有些窗口会变小。

解决办法:

naive方案,把所有变小的窗口pad后,计算attention时候把pad数值掩膜。但这样,存在很多没必要的运算。

Efficient batch computation approach by cyclic-shifting toward the top-left direction

主要思想:将移动模式作为flag,只对有相邻关系的子窗口计算,不相邻的,减去100,使得softmax计算后概率接近0。

示意图传递了跟之前的窗口类似的计算,对于不相关的信息加上了mask,softmax后得到的概率接近0,使其达到mask的作用。最后再通过reverse cyclic shift移动回去。

SwinTransformer Framework

整体架构的思想:4个阶段,每个阶段构建不同大小的特征图,不断缩小分辨率,类似CNN逐渐增大感受野。

  • Patch partition, 本质就是矩阵的reshape,以4*4为一个图像块,对输入图片进行分块,然后在channel方向上进行拼接
  • Linear embedding, 经过线性变换,通道数从48变成C
  • Patch merging, 本质就是降采样,只是比pooling的方式来得更复杂一些,有学习的参数

关于relative position bias这里不展开,因为其在图像分类上提高了,但是在目标检测任务上降低了,具体理解可以参考[4]。

Application:下一篇会介绍SwinIR,揭开该方法如何在底层视觉的图像修复上施展魔法。

参考:

[1] cs231n课件

[2] vit-pytorch/vit_pytorch/vit.py at main · lucidrains/vit-pytorch · GitHub

[3] Swin Transformer:屠榜各大CV任务的视觉Transformer模型 (high-level介绍)

[4] Swin Transformer 详解(detail-level理解,很不错)

  • 29
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值