第6周学习笔记:Vision Transformer Swin Transformer

一.Vision Transformer

对比ViT(“纯"Transformer模型)、Resnet网络和Hybrid(传统CNN和Transformer混合模型)

1 模型架构

输入一张图片,会把它分成一个一个patches,然后把每个patches输入进Embedding层,然后会得到一个个向量(token),之后在这些token前面加一个class token用于分类,接着需要加上位置信息(Position Embedding),把这一系列token输入到Transformer Encoder中,然后提取class token对应的输出,通过MLP Head得到最终分类的结果

简单而言,模型由三个模块组成:

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder(图右侧有给出更加详细的结构)
  • MLP Head(最终用于分类的层结构)

Embedding层

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token,token_dim],

在代码实现中,直接通过一个卷积层来实现。以ViT-B/16为例,使用卷积核大小为16x16,stride为16,卷积核个数为768 ,通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵。

  • 在输入Transformer Encoder之前需要加上[class]token以及Position Embedding,都是可训练参数
  • 拼接[class]token: Cat([1,768],[196,768])-> [197,768]
  • 叠加Position Embedding:[197,768]->[197,768]

如果不使用Position Embedding

 训练得到的位置编码每个位置和其他位置上的余弦相似度,可以看到每个patches和与它在同一行或同一列的patches相似度比较高

Transformer Encoder层

Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由以下几部分组成:

  • Layer Norm 对每个token进行Norm处理
  • Multi-Head Attention
  • Dropout/DropPath,在原论文的代码中是直接使用的Dropout层
  • MLP Block,如图右侧所示,就是全连接+GELU激活函数+Dropout,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]

 MLP Head层

  • 注意,在Transformer Encoder前有个Dropout层,后有一个LayerNorm
  • 训练lmageNet21K时是由Linear+tanh激活函数+Linear,但是迁移到ImageNet1K上或者你自己的数据上时,只有一个Linear,可以简单理解成一个全连接层

网络流程

假设输入图片是rgb彩色图片,首先通过Embedding(768个16*16卷积+在高度和宽度方向进行展平处理),然后concat一个class token,再加上Position Embedding,通过Dropout层,送入Transformer Encoder,也就是将Encoder Block重复12次,出来后通过LayerNorm,并通过切片提取Class Token对应的输出,最后通过MLP Head得到最后的输出。

2. Hybrid模型

  • R50的卷积层采用的stdConv2d,不是传统的Conv2d
  • 将所有的BatchNorm层替换成GroupNorm层
  • 把stage4中的3个Block移至stage3中

先通过R50 Backbone进行特征提取,得到的特征矩阵shape是[14, 14, 1024],接着再输入Patch Embedding层,注意Patch Embedding中卷积层Conv2d的kernel_size和stride都变成了1,只是用来调整channel。后面的部分和前面ViT中讲的完全一样。

3. 模型代码学习

  • Layers是Transformer Encoder中重复堆叠Encoder Block的次数
  • Hidden Size是通过Embedding层后每个token的dim(向量的长度)
  • MLP size是Transformer Encoder中MLP Block第一个全连接的节点个数(是Hidden Size的四倍)
  • Heads代表Transformer中Multi-Head Attention的heads数

 Multi-Head Attention

class Attention(nn.Module):
    def __init__(self,
                 dim,   # 输入token的dim
                 num_heads=8,  #head数目
                 qkv_bias=False, # 用qkv时是否使用偏置
                 qk_scale=None,
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5 # 不传入就等于根号下dk分之一
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):
        # [batch_size, num_patches + 1(class token), total_embed_dim]
        B, N, C = x.shape

        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]
        # permute调整顺序: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]
        # @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1) # 对每一行进行softmax处理
        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C) # 对权重和v加权求和
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

模型结构

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
       
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 2 if distilled else 1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))# batch维度 1 768
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
        # (batch维度 14*14+1 embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        # 加上Position Embedding之后的dropout层
        self.pos_drop = nn.Dropout(p=drop_ratio)

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
        # 重复堆叠Encoder Block
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
        self.norm = norm_layer(embed_dim)

        # Representation layer
        if representation_size and not distilled:
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = self.pos_drop(x + self.pos_embed)
        x = self.blocks(x)
        x = self.norm(x)
        # 提取class token对应的输出
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
            x = self.head(x) # 全连接层
        return x

二 .Swin Transformer

1网络整体框架

Swin Transformer vs Vision Transformer,能减少计算量,准确率更高

注意,每下采样两倍,channel都会进行翻倍的操作

Patch Partition:首先将图片输入到Patch Partition模块中进行分块,即每4x4相邻的像素为一个Patch,然后在channel方向展平(flatten)

然后就是通过四个Stage构建不同大小的特征图,除了Stage1中先通过一个Linear Embeding层外,剩下三个stage都是先通过一个Patch Merging层进行下采样,都是重复堆叠Swin Transformer Block。最后对于分类网络,后面还会接上一个Layer Norm层、全局池化层以及全连接层得到最终输出。

2 Patch Merging

Patch Merging操作首先会将临近2*2范围内的patch拼接起来,然后将每个patch中相同位置(同一颜色)像素拼在一起就得到4个feature map。接着将这四个feature map在深度方向进行concat拼接,然后在通过一个LayerNorm层。最后通过一个全连接层在feature map的深度方向做线性变化,将feature map的深度由C变成C/2(对于每个patch而言,维度由C上升至2C)。

然后该特征送入几个Transformer Block中,得到Stage 2。经过Stage 2,特征图变为原图的1/8(H / 8,W / 8)。以此类推,得到Stage 3 (H / 16, W / 16)和 Stage 4(H / 32,W / 32)。

3 W-MSA(Windows Multi-head Self-Attention)

普通MSA:对每个块求qkv,特征矩阵中的每个像素,都会和特征矩阵所有的像素进行信息的沟通。

W-MSA:首先将feature map按照MxM(例子中的M=2)大小划分成一个个Windows,然后单独对每个Windows内部进行Self-Attention。这样的目的是为了减少计算量,但是会导致窗口之间无法进行信息交互

计算量:

  • h代表feature map的高度
  • w代表feature map的宽度
  • C代表feature map的深度
  • M代表每个窗口(Windows)的大小

 假设feature map的h、w都为112,M=7,C=128,采用W-MSA模块相比MSA模块能够节省约40124743680 FLOPs:

4 SW-MSA

采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了Shifted Windows Multi-Head Self-Attention(SW-MSA)模块,即在第L+1层使用进行偏移的W-MSA。偏移后的窗口,比如对于第一行第2列的2x4的窗口,它能够使第L层的第一排的两个窗口信息进行交流。

但是窗口移位后,会出现一些小块的窗口小于M*M,另外从4个窗口增加到9个窗口,计算量也增加了。

native method:填充到所需要的大小,并且在计算的时候忽略这些填充的部分——计算量太大了

Batch Computation Approach:向左上方循环移动,移位后,一个窗口可能由几个在特征映射中不相邻的子窗口组成(实现了把自注意限制在子窗口内)

另外,为了防止把不同的区域合并在一起可能会导致的信息乱窜,在实际计算中使用的是masked MSA即带蒙板mask的MSA,这样就能够通过设置蒙板来隔绝不同区域的信息了。

可以将像素0与区域3中的所有像素匹配结果都减去100,由于α的值都很小,一般都是零点几的数,将其中一些数减去100后在通过SoftMax得到对应的权重都等于0了

注意,在计算完后需要把数据给挪回到原来的位置上

5 相对位置偏执

相对位置编码主要是为了解决Self-Attention中的排列不变性的问题,即不同顺序输入的tokens会得到一样的结果。

相对位置索引

作者将2元坐标转化为1元坐标:偏移从0开始,行、列标加上M-1(例子中M是2),接着将所有的行标都乘上2M-1,最后将行标和列标进行相加。这样即保证了相对位置关系,而且不会出现上述0+(−1)=(−1)+0的问题了

真正使用到的可训练参数B是保存在relative position bias table表里的,这个表的长度是等于(2M−1)×(2M−1)的。那么公式中的相对位置偏执参数B是根据上面的相对位置索引表根据查relative position bias table表得到的,如下图所示。

三. ConvNeXt

ConvNeXt的提出强行给卷积神经网络续了口命。ConvNeXt使用的全部都是现有的结构和方法,没有任何结构或者方法的创新。

作者首先利用训练vision Transformers的策略去训练原始的ResNet50模型,发现比原始效果要好很多,并将此结果作为后续实验的基准

1. Macro design

Changing stage compute ratio:在Swin Transformer中,stage3堆叠block的占比更高,于是作者将ResNet50中的堆叠次数由(3, 4, 6, 3)调整成(3, 3, 9, 3)

Changing stem to “Patchify”:作者将ResNet中的stem也换成了和Swin Transformer一样的patchify

 2. ResNeXt

作者借鉴了ResNeXt中的组卷积grouped convolution,因为ResNeXt相比普通的ResNet而言在FLOPs以及accuracy之间做到了更好的平衡。作者采用group数和通道数channel相同的depthwise convolution,这样做的另一个原因是作者认为depthwise convolution和self-attention中的加权求和操作很相似。

3. Inverted Bottleneck

作者认为Transformer block中的MLP模块非常像MobileNetV2中的Inverted Bottleneck模块,即两头细中间粗。下图a是ReNet中采用的Bottleneck模块,b是MobileNetV2采用的Inverted Botleneck模块,c是ConvNeXt采用的是Inverted Bottleneck模块

 4. Large Kernel Sizes

Moving up depthwise conv layer:因为在Transformer中,MSA模块是放在MLP模块之前的,所以这里进行效仿,将depthwise conv上移

Increasing the kernel size:将depthwise conv的卷积核大小由3x3改成了和Swin Transformer一样的7x7

5.Micro Design

更改一些更细小的差异

Replacing ReLU with GELU:在Transformer中激活函数基本用的都是GELU,但替换后发现准确率没变化

Fewer activation functions:在Transformer中并不是每个模块后都跟有激活函数,比如MLP中只有第一个全连接层后跟了GELU激活函数。所以作者在ConvNeXt Block中也减少激活函数的使用

Fewer normalization layers:在Transformer中,Normalization使用的也比较少,接着作者也减少了ConvNeXt Block中的Normalization层

Substituting BN with LN:在Transformer中基本都用的Layer Normalization,(因为最开始Transformer是应用在NLP领域的,BN不适用于NLP相关任务)。接着作者将BN全部替换成了LN,发现准确率还有小幅提升

Separate downsampling layers:与Swin Transformer看齐,使用单独的下采样层

6.ConvNeXt variants

对于ConvNeXt网络,作者提出了T/S/B/L四个版本,计算复杂度刚好和Swin Transformer中的T/S/B/L相似。

 ConvNeXt-T 结构图

每个stage都是由ConvNeXt Block组成

感想

在阅读论文、学习模型的过程中,要学会利用已有的模型和经验,参考优秀模型的改进思想。

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值