Segformer 论文笔记

SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers

SegFormer

论文链接: https://arxiv.org/abs/2105.15203
代码链接: https://github.com/NVlabs/SegFormer
Demo链接: https://www.bilibili.com/video/BV1MV41147Ko/

一、 Problem Statement

SETR 采用了 ViT 作为backbone,然后与CNN的decoders结合来扩大特征图的分辨率。但是ViT有以下缺点:

  1. ViT的输出是single-scale low-resolution的特征图,而不是multi-scale的。
  2. 计算量大。

对于第一个缺点, (pyramid vision transformer)PVT 基于 ViT提出了金字塔结构。但是,当PVT结合Swin Transformer和Twins的时候,这些方法主要考虑的是Transformer encoder,忽略了对decoder的进一步改进和提升。

二、 Direction

基于效率,精度和鲁棒性,作者提出了以下两个方向:

  1. positional-encoding-free and hierarchical Transformer encoder
  2. lightweight ALL-MLP decoder

所提出的encoder避免插入位置信息,对于训练和测试不同分辨率的输入,不会影响性能。除此之外,hierarchical部分可以使得encoder产生高分辨率精细的特征图和低分辨率粗糙的特征图。其次,提出的lightweight ALL-MLP decoder是Transformer产生的特征,其中较低层的注意力往往停留在局部,而最高层的注意力是非局部的。通过融合不同层的信息,MLP encoder结合了局部和全局的注意力。

三、 Method

先来看一下整体的网络结构:
在这里插入图片描述

可以看到,网络主要分两个部分:

  1. A hierarchical Transformer encoder
  2. A lightweight ALL-MLP decoder

输入一张图片 H × W × 3 H \times W \times 3 H×W×3,首先会分成size为 4 × 4 4 \times 4 4×4的patches。使用小一点的patches size有益于dense prediction task。 然后使用这些patches作为输入,输入到hierarchical Transformer encoder里面来获得多层级的特征图,分别为原图的 { 1 / 4 , 1 / 8 , 1 / 16 , 1 / 32 } \{1/4, 1/8, 1/16, 1/32\} {1/4,1/8,1/16,1/32}大小。然后把这些多层级的特征图输入到ALL-MLP decoder中,预测大小为 H 4 × W 4 × N c l s \frac{H}{4} \times \frac{W}{4} \times N_{cls} 4H×4W×Ncls的segmentation mask,其中 N c l s N_{cls} Ncls是类别的数量。

1. Hierarchical Transformer Encoder

Encoder由图示模块所堆叠起来: 包括Efficient Self-Attention,Mix-FFNOverlap Patch Embedding三个模块。

  • Overlap Patch Embedding
    给定一个大小为 H × W × 3 H \times W \times 3 H×W×3的输入,使用patch merging来获得多层级的feature map F i F_i Fi,对应的分辨率大小为: H 2 i − 1 × W 2 i + 1 × C i , i ∈ { 1 , 2 , 3 , 4 } \frac{H}{2^{i-1} } \times \frac{W}{2^{i+1} } \times C_i, \quad i \in \{1,2,3,4\} 2i1H×2i+1W×Ci,i{1,2,3,4},其中通道数 C i + 1 C_{i+1} Ci+1大于 C i C_i Ci。 Overlapped Patch merging操作会把 N × N × 3 N \times N \times 3 N×N×3的输入变为 1 × 1 × C 1\times 1\times C 1×1×C的vector。不太了解的话,我们看以下源代码一些片段:

    
    
    class MixVisionTransformer(nn.Module):
        def __init__(self, ...):
            super().__init__()
            self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, 
                                                  in_chans=in_chans, embed_dim=embed_dims[0])
    
            self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, 
                                                  in_chans=embed_dims[0], embed_dim=embed_dims[1])
    
            self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2,
                                                  in_chans=embed_dims[1], embed_dim=embed_dims[2])
    
            self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, 
                                                  in_chans=embed_dims[2], embed_dim=embed_dims[3])
    
    
    
    class OverlapPatchEmbed(nn.Module):
        def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
            super().__init__()
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))
            self.norm = nn.LayerNorm(embed_dim)
        
    
        def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
    
        return x, H, W 
    
    

    可以看到,Overlap patch merging操作主要是2D卷积,通过改变patch_size和stride把特征图进行缩放,形成了特征层级结构。ViT中使用的patch merging过程结合的是非重叠的patches,因此会导致无法保留这些patches之间的局部联系。而这里所用的是overlapping patch merging。通过patch_size和stride的不同达到效果

  • Efficient Self-Attention
    encoder的主要计算瓶颈在self-attention层。在原始的multi-head self-attention过程中,每一个heads Q , K , V Q, K, V Q,K,V都是相同的维度 N × C N \times C N×C,其中 N = H × W N = H \times W N=H×W是序列的长度。所以self-attention定义为:
    Attention(Q, K, V) = Softmax ( Q K T d h e a d ) V \text{Attention(Q, K, V)} = \text{Softmax}(\frac{QK^T}{\sqrt{d_{head}}})V Attention(Q, K, V)=Softmax(dhead QKT)V

    这个过程的复杂度是 O ( N 2 ) O(N^2) O(N2),对于分辨率大的图像是难以承担的。所以作者提出了 sequence reduction process。也就是PVT中提出来的spatial reduction操作
    K ^ = R e s h a p e ( N R , C ⋅ R ) ( K ) K = L i n e a r ( C ⋅ R , C ) ( K ^ ) \hat{K} = Reshape(\frac{N}{R}, C\cdot R)(K) \\ K = Linear(C \cdot R, C) (\hat{K}) K^=Reshape(RN,CR)(K)K=Linear(CR,C)(K^)

    简单来说,输入维度为 N × C N\times C N×C的K 首先会reshape成维度为 N R × ( C ⋅ R ) \frac{N}{R} \times (C\cdot R) RN×(CR)维度。然后通过Linear变换,將 ( C ⋅ R ) (C \cdot R) (CR)的维度变为 C C C。这样输出的 K K K维度为 N R × C \frac{N}{R}\times C RN×C,计算复杂度为 O ( N 2 R ) O(\frac{N^2}{R}) O(RN2)。作者把reduction ratio设置为 [ 64 , 16 , 4 , 1 ] [64,16,4,1] [64,16,4,1],分别对应stage-1到stage-4。

    因此,在本文中 Q , K , V Q, K, V Q,K,V的维度分别为 N × C , N R × C , N R × C N \times C, \frac{N}{R} \times C, \frac{N}{R}\times C N×C,RN×C,RN×C。 因此 softmax的维度为: ( Q K T ) V (Q K^T) V (QKT)V,也就是 ( N × C ) × ( C × N R ) × ( N R × C ) = ( N × N R ) × ( N R × C ) = N × C (N \times C) \times (C \times \frac{N}{R}) \times ( \frac{N}{R} \times C) = (N \times \frac{N}{R}) \times (\frac{N}{R} \times C) = N \times C (N×C)×(C×RN)×(RN×C)=(N×RN)×(RN×C)=N×C

    下面来看一下代码节选:

      class Attention(nn.Module):
          def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
              super().__init__()
              assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
    
              self.dim = dim
              self.num_heads = num_heads
              head_dim = dim // num_heads
              self.scale = qk_scale or head_dim ** -0.5
    
              self.q = nn.Linear(dim, dim, bias=qkv_bias)
              self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
              self.attn_drop = nn.Dropout(attn_drop)
              self.proj = nn.Linear(dim, dim)
              self.proj_drop = nn.Dropout(proj_drop)
    
              self.sr_ratio = sr_ratio
              if sr_ratio > 1:
                  self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                  self.norm = nn.LayerNorm(dim)
    
              
          def forward(self, x, H, W):
              B, N, C = x.shape
              q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
    
              if self.sr_ratio > 1:
                  x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                  x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                  x_ = self.norm(x_)
                  kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
              else:
                  kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
              k, v = kv[0], kv[1]
    
              attn = (q @ k.transpose(-2, -1)) * self.scale
              attn = attn.softmax(dim=-1)
              attn = self.attn_drop(attn)
    
              x = (attn @ v).transpose(1, 2).reshape(B, N, C)
              x = self.proj(x)
              x = self.proj_drop(x)
    
              return x
    
  • Mix-FFN
    ViT使用了position embedding来引入定位信息。但是,position embedding的大小是固定的,因此,当测试的时候输入分辨率与训练的时候的分辨率不一致的话,positional code就需要被插值,导致精度的下降。所以CPV在positional embedding上使用了3X3卷积,来实现数据驱动的一个positional encoding。本文提出的是:positional encoding对于语义分割是没有必要的。 作者引入的是Mix-FFN模块,通过在Feed-forward network(FFN)上直接使用3X3卷积,减弱了zero-padding会丢失一些定位信息的影响。
    x o u t = MLP ( GELU ( Conv 3 × 3 ( MLP ( x i n ) ) ) ) + x i n x_{out} = \text{MLP}(\text{GELU}(\text{Conv}_{3\times 3}(\text{MLP}(x_{in}))))+x_{in} xout=MLP(GELU(Conv3×3(MLP(xin))))+xin

    其中, x i n x_{in} xin是self-attention模块中的输出特征。作者实验表明,3X3卷积足够提供位置信息给Transformer。特别地,作者使用了depth-wise convolutions来减少参数量来提升效率。
    代码节选:

    class Mlp(nn.Module):
        def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.dwconv = DWConv(hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.drop = nn.Dropout(drop)
    
        def forward(self, x, H, W):
            x = self.fc1(x)
            x = self.dwconv(x, H, W)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x
    

2. Lightweight ALL-MLP Decoder

Decoder是由MLP组成的。能够使用如此简单的decoder是因为hierarchical Transformer encoder比传统的CNN encoders有较大的有效感知域。
在这里插入图片描述

可以看到,Decoder分为四个步骤:

  • multi-level features F i F_i Fi 会输入到MLP层,使得通道的维度一样。

    F ^ i = Linear ( C i , C ) ( F i ) , ∀ i \hat{F}_i = \text{Linear}(C_i, C)(F_i), \forall i F^i=Linear(Ci,C)(Fi),i

  • features map会进行上采样到原图的 1 4 \frac{1}{4} 41,并且拼接在一起。
    F ^ i = Upsample ( W 4 × W 4 ) ( F ^ i ) , ∀ i \hat{F}_i = \text{Upsample}(\frac{W}{4} \times \frac{W}{4})(\hat{F}_i), \forall i F^i=Upsample(4W×4W)(F^i),i

  • 再次使用MLP层把拼接的特征融合。
    F = Linear ( 4 C , C ) ( Concat ( F ^ i ) ) , ∀ i F = \text{Linear}(4C, C)(\text{Concat}(\hat{F}_i)), \forall i F=Linear(4C,C)(Concat(F^i)),i

  • 另外一个MLP层將融合的特征进行segmentation mask的预测。输入维度为 H 4 × W 4 × N c l s \frac{H}{4} \times \frac{W}{4} \times N_cls 4H×4W×Ncls
    M = Linear ( C , N c l s ) ( F ) M = \text{Linear}(C, N_{cls})(F) M=Linear(C,Ncls)(F)

通过下面源码可知,第一步的Linear是FC层,Upsample是bilinear插值法。第三步的特征融合和第四步预测都是卷积操作。

from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule


class MLP(nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x

class SegFormerHead(BaseDecodeHead):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
    """
    def __init__(self, feature_strides, **kwargs):
        super(SegFormerHead, self).__init__(input_transform='multiple_select', **kwargs)
        assert len(feature_strides) == len(self.in_channels)
        assert min(feature_strides) == feature_strides[0]
        self.feature_strides = feature_strides

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        decoder_params = kwargs['decoder_params']
        embedding_dim = decoder_params['embed_dim']

        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)

        self.linear_fuse = ConvModule(
            in_channels=embedding_dim*4,
            out_channels=embedding_dim,
            kernel_size=1,
            norm_cfg=dict(type='SyncBN', requires_grad=True)
        )

        self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)

    def forward(self, inputs):
        x = self._transform_inputs(inputs)  # len=4, 1/4,1/8,1/16,1/32
        c1, c2, c3, c4 = x

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape

        _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
        #  F.interpolate(input, size, scale_factor, mode, align_corners)
        _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)

        _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)

        _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)

        _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])

        _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        x = self.dropout(_c)
        x = self.linear_pred(x)

        return x

3. 整体结构

不同大小的网络结构如下图所示。

所对应的参数分别为:

四、 Conclusion

提出了一个结合了positional-encoding-free,hierarchical Transformer encoder和lightweight ALL-MLP decoder语义分割网络。改进了ViT和应用了PVT。

Reference

  • 21
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值