EfficientFormer v2 (ICCV 2023, Snap): Principles and Code Walkthrough

Paper: Rethinking Vision Transformers for MobileNet Size and Speed

Official implementation: https://github.com/snap-research/efficientformer

Third-party implementation: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/efficientformer_v2.py

Background

Vision Transformers (ViTs) have achieved remarkable success in computer vision tasks, but deploying them on mobile devices remains challenging. Recent lightweight ViTs tend to optimize for a single metric and are much less competitive on the others. For example, MobileViT has very few parameters but runs several times slower than lightweight CNNs; EfficientFormer is fast on mobile devices but has a relatively large parameter count; LeViT and MobileFormer achieve low FLOPs at the cost of redundant parameters. The core question of this paper is: can we design a Transformer model that matches MobileNet in both speed and size while maintaining high accuracy?

Contributions

Building on EfficientFormer, this paper revisits the design choices of ViTs and proposes a new supernet with low latency and high parameter efficiency. With a novel fine-grained joint search strategy, efficient architectures can be found while jointly optimizing latency and parameter count. Specifically:

  1. Novel supernet design: a low-latency, parameter-efficient supernet for designing efficient Transformer models.
  2. Fine-grained joint search strategy: a new search strategy that jointly optimizes latency and parameter count to find the best architecture. A new evaluation metric, MES (Mobile Efficiency Score), is defined to assess a network's performance on mobile devices by combining model size and latency (a hedged sketch of its definition follows this list).
  3. Performance gain: the proposed EfficientFormerV2 achieves 3.5% higher top-1 accuracy than MobileNetV2 on ImageNet-1K with similar latency and parameter count.
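
For reference, the Mobile Efficiency Score is, to the best of my recollection of the paper (treat the exact constants as an assumption and check the paper), defined roughly as

\[ \text{MES} = \text{Score} \cdot \prod_i \left(\frac{M_i}{U_i}\right)^{-\alpha_i}, \qquad \sum_i \alpha_i = 1, \]

where Score is a predefined base score (e.g. 100), \(M_i\) is a metric such as model size or on-device latency, \(U_i\) is its unit (e.g. 3MB for size, 1ms for latency), and \(\alpha_i\) weights the importance of each metric.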

Method

The authors first take EfficientFormer as the baseline and revisit the design choices of efficient ViTs.

Token Mixers vs. Feed Forward Network

Incorporating local information improves performance and makes ViTs more robust when no explicit position embedding is used. PoolFormer and EfficientFormer use 3x3 average pooling as the local token mixer, as shown in Figure 2(a). Replacing it with a depthwise convolution of the same size introduces no extra latency, and with a negligible 0.02M additional parameters it improves accuracy by 0.6%. Moreover, recent work has shown that injecting local information into the FFN of a ViT also helps. Note, however, that capturing locality with an extra 3x3 depthwise convolution inside the FFN overlaps with the role of the original local mixer. The authors therefore remove the residual-connected local token mixer and move the 3x3 depthwise convolution into the FFN, obtaining a unified FFN, as shown in Figure 2(b). Applying the unified FFN to all stages raises accuracy to 80.3% with only 0.1M extra parameters, as shown in Table 1.
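
As a rough illustration, a unified FFN block of this shape can be sketched in PyTorch as below. This is a simplified sketch, not the timm implementation (the actual module, ConvMlpWithNorm, is shown in the code walkthrough later); the name UnifiedFFN is made up here.

import torch
import torch.nn as nn

class UnifiedFFN(nn.Module):
    """Sketch of the unified FFN: 1x1 expand -> 3x3 depthwise (local mixing) -> 1x1 project."""
    def __init__(self, dim, expansion=4):
        super().__init__()
        hidden = dim * expansion
        self.fc1 = nn.Sequential(nn.Conv2d(dim, hidden, 1), nn.BatchNorm2d(hidden), nn.GELU())
        self.mid = nn.Sequential(
            nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden),  # depthwise conv injects locality
            nn.BatchNorm2d(hidden), nn.GELU(),
        )
        self.fc2 = nn.Sequential(nn.Conv2d(hidden, dim, 1), nn.BatchNorm2d(dim))

    def forward(self, x):
        # residual shown here for readability; in timm it is applied in
        # EfficientFormerV2Block together with drop-path and layer scale
        return x + self.fc2(self.mid(self.fc1(x)))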

Search Space Refinement

The authors then adjust the network's depth and width and find that a deeper and narrower network gives higher accuracy (+0.2%), fewer parameters (-0.13M), and lower latency (-0.1ms), as shown in Table 1. This network (80.5% accuracy) is taken as the new baseline for the subsequent design changes.

A 5-stage design is widely used in other networks. The authors also ran a 5-stage experiment (Table 1): accuracy dropped to 80.31% while parameters increased by 0.39M and latency by 0.2ms, so the 4-stage design is kept.

MHSA Improvements

The authors study techniques that improve the attention module without extra cost in model size or latency. As shown in Figure 2(c), there are two: first, injecting local information into the Value matrix through an additional 3x3 depthwise convolution; second, adding fully connected layers across the head dimension to enable communication between attention heads, i.e., Talking Head (see the Talking-Heads Attention - CSDN blog for details). With these two changes accuracy further improves to 80.8%.
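
To make the two tweaks concrete, here is a minimal, self-contained sketch (my own simplification with a made-up class name, not the timm code, which appears in full in the code walkthrough below): a depthwise 3x3 convolution on V provides the local branch, and two 1x1 convolutions over the head dimension implement talking heads before and after the softmax.

import torch
import torch.nn as nn

class TalkingHeadAttention2d(nn.Module):
    """Sketch: MHSA on a 2D feature map with locality injected into V and talking heads."""
    def __init__(self, dim, num_heads=8, key_dim=32):
        super().__init__()
        self.num_heads, self.key_dim = num_heads, key_dim
        self.scale = key_dim ** -0.5
        inner = num_heads * key_dim
        self.q = nn.Conv2d(dim, inner, 1)
        self.k = nn.Conv2d(dim, inner, 1)
        self.v = nn.Conv2d(dim, inner, 1)
        self.v_local = nn.Conv2d(inner, inner, 3, padding=1, groups=inner)  # locality on V
        self.th1 = nn.Conv2d(num_heads, num_heads, 1)  # talking heads before softmax
        self.th2 = nn.Conv2d(num_heads, num_heads, 1)  # talking heads after softmax
        self.proj = nn.Conv2d(inner, dim, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        q = self.q(x).reshape(B, self.num_heads, self.key_dim, N).transpose(-2, -1)  # (B, h, N, d)
        k = self.k(x).reshape(B, self.num_heads, self.key_dim, N)                    # (B, h, d, N)
        v = self.v(x)
        v_local = self.v_local(v)                                                    # local branch on V
        v = v.reshape(B, self.num_heads, self.key_dim, N).transpose(-2, -1)          # (B, h, N, d)
        attn = self.th2((self.th1(q @ k * self.scale)).softmax(dim=-1))              # mix heads around softmax
        out = (attn @ v).transpose(-2, -1).reshape(B, -1, H, W) + v_local            # global + local on V
        return self.proj(out)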

Attention on Higher Resolution

Attention helps accuracy, but because its time complexity is quadratic in resolution, applying it to high-resolution inputs hurts efficiency. The previous baseline applies attention only in the last stage, whose resolution is \(\frac{1}{32}\) of the input. The authors tried adding attention to the second-to-last stage as well: accuracy improved by 0.9%, but inference became 2.8x slower. They therefore looked for ways to reduce the complexity of the attention module. Prior work mitigates this with window-based attention or by downsampling Keys and Values, but the authors find neither suitable for mobile deployment: window-based attention is hard to accelerate on mobile devices because of its complex window partitioning and reordering, while downsampling only Keys and Values requires keeping the Queries at full resolution so that the output resolution is preserved after the attention matrix multiplication. Their experiments show that the latter reduces latency to 2.8ms but is still 2x slower than the baseline.

The authors therefore downsample Query, Key, and Value to a fixed size (\(\frac{1}{32}\) of the input) and recover the resolution by interpolation, as shown in Figure 2(d)(e). As Table 1 shows, this reduces latency from 3.5ms to 1.5ms while keeping competitive accuracy (81.5% vs. 81.7%).

Dual-Path Attention Downsampling

Most vision backbones use strided convolutions or pooling layers for static or learnable downsampling, which forms the hierarchical structure. Some recent work explores attention downsampling; for example, LeViT and UniNet propose halving the resolution through the attention mechanism to achieve context-aware downsampling with a global receptive field. Concretely, the number of tokens in the query is halved, so the output of the attention module is halved as well.

To reach acceptable inference speed on mobile devices, applying attention downsampling to the high-resolution inputs of the early stages is impractical, which limits the value of existing approaches that search over different downsampling methods at high resolution.

This paper proposes a combined strategy, dual-path attention downsampling, which captures both local and global dependencies, as shown in Figure 2(f). To obtain the downsampled query, pooling is used as static local downsampling and a 3x3 depthwise convolution as learnable local downsampling; the two results are summed and projected to the query dimension. In addition, the attention downsampling module is connected through a residual path to a regular strided convolution, forming a local-global combination similar to a downsampling bottleneck or inverted bottleneck. Dual-path attention downsampling further raises accuracy to 81.8%.

Jointly Optimizing Model Size and Speed

The baseline EfficientFormer was found by a latency-driven search and is fast on mobile devices, but it has two problems. First, the search only constrains speed, leaving the final model with redundant parameters. Second, it only searches depth and width, which is coarse-grained. In fact, most of a network's parameters and computation sit in the FFN, and the FFN's parameter count and computational complexity are linear in its expansion ratio. Previous networks use the same expansion ratio for the FFNs of all stages; searching over expansion ratios yields a more fine-grained search space. This paper therefore proposes a search algorithm with flexible per-block configurations that jointly constrains size and speed.
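
To see why the FFN cost is linear in the expansion ratio, here is a back-of-the-envelope parameter count for the unified-FFN structure described above (my own estimate, ignoring norm and bias terms; not a figure from the paper):

def unified_ffn_params(dim: int, expansion: float) -> int:
    """Rough parameter count of a unified FFN block (two 1x1 convs + one 3x3 depthwise)."""
    hidden = int(dim * expansion)
    fc1 = dim * hidden      # 1x1 expand
    mid = 9 * hidden        # 3x3 depthwise
    fc2 = hidden * dim      # 1x1 project
    return fc1 + mid + fc2  # = expansion * (2*dim**2 + 9*dim), i.e. linear in the expansion ratio

for e in (2, 3, 4):
    print(e, unified_ffn_params(96, e))  # dim=96: 38592, 57888, 77184 -> grows linearly with e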

The specific search objective, search space, supernet, and search algorithm are not covered in detail here.

Experimental Results

The detailed configurations of the model variants are given in the paper.

The results on ImageNet are shown in Table 2: compared with EfficientFormer v1, v2 achieves higher accuracy at similar parameter counts and latency.

Code Walkthrough

The following uses the timm implementation as the example, with an input of size (1, 3, 224, 224) and the model "efficientformerv2_s0".
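
For reference, the model can be built and run through timm like this (a minimal example; pretrained=False avoids a weight download, and the model name assumes a timm version that registers efficientformerv2_s0):

import torch
import timm

model = timm.create_model("efficientformerv2_s0", pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) with the default 1000-class head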

The first change replaces the PoolFormer block in the early stages with the unified FFN, as shown in Figure 2(b). The forward function of the EfficientFormerV2Block class is shown below. The blocks of the first two stages are built with use_attn=False, so self.token_mixer is None and only self.mlp is executed.

def forward(self, x):
    if self.token_mixer is not None:  # only the last two stages use attention as the token mixer
        x = x + self.drop_path1(self.ls1(self.token_mixer(x)))
    x = x + self.drop_path2(self.ls2(self.mlp(x)))  # unified FFN (ConvMlpWithNorm)
    return x

The self.mlp of the first block in the first stage is shown below; it matches Figure 2(b) exactly.

ConvMlpWithNorm(
  (fc1): ConvNormAct(
    (conv): Conv2d(32, 128, kernel_size=(1, 1), stride=(1, 1))
    (bn): BatchNormAct2d(
      128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): GELU()
    )
  )
  (mid): ConvNormAct(
    (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=128)
    (bn): BatchNormAct2d(
      128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): GELU()
    )
  )
  (drop1): Dropout(p=0.0, inplace=False)
  (fc2): ConvNorm(
    (conv): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (drop2): Dropout(p=0.0, inplace=False)
)

The second change is the improved MHSA, shown in Figure 2(c). The full class is listed below with some added comments. First, locality is injected into V: the figure seems to show V passing through the 3x3 depthwise convolution before the attention matrix is computed, but in the actual implementation the output of the depthwise convolution on V (self.v_local) is added as a residual to the attention result computed with the original V (the `+ v_local` in the forward). Second, communication between heads, i.e., talking head, is added (see the Talking-Heads Attention - CSDN blog for details); it is implemented by self.talking_head1 and self.talking_head2.

The third change adds attention to the second-to-last stage, whose resolution is \(\frac{1}{16}\) of the input, as shown in Figure 2(e). Before computing attention the feature map is downsampled by a strided depthwise convolution (self.stride_conv), and after the attention it is upsampled back by interpolation (self.upsample).

class Attention2d(torch.nn.Module):
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(
            self,
            dim=384,  # 96
            key_dim=32,
            num_heads=8,
            attn_ratio=4,
            resolution=7,
            act_layer=nn.GELU,
            stride=None,  # 2
    ):
        super().__init__()
        self.num_heads = num_heads  # 8
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim  # 32

        resolution = to_2tuple(resolution)  # (14,14)
        if stride is not None:
            resolution = tuple([math.ceil(r / stride) for r in resolution])  # (7,7)
            self.stride_conv = ConvNorm(dim, dim, kernel_size=3, stride=stride, groups=dim)
            self.upsample = nn.Upsample(scale_factor=stride, mode='bilinear')
        else:
            self.stride_conv = None
            self.upsample = None

        self.resolution = resolution  # (7,7)
        self.N = self.resolution[0] * self.resolution[1]
        self.d = int(attn_ratio * key_dim)  # 4*32=128
        self.dh = int(attn_ratio * key_dim) * num_heads  # 128*8=1024
        self.attn_ratio = attn_ratio  # 4
        kh = self.key_dim * self.num_heads  # 32*8=256

        self.q = ConvNorm(dim, kh)  # 96,256
        self.k = ConvNorm(dim, kh)  # 96,256
        self.v = ConvNorm(dim, self.dh)  # 96,1024
        self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, groups=self.dh)
        self.talking_head1 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1)
        self.talking_head2 = nn.Conv2d(self.num_heads, self.num_heads, kernel_size=1)

        self.act = act_layer()
        self.proj = ConvNorm(self.dh, dim, 1)

        pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)
        rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
        rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]
        self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, self.N))
        self.register_buffer('attention_bias_idxs', torch.LongTensor(rel_pos), persistent=False)
        self.attention_bias_cache = {}  # per-device attention_biases cache (data-parallel compat)

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, x):  # (1,96,14,14)
        B, C, H, W = x.shape
        if self.stride_conv is not None:
            # ConvNorm(
            #   (conv): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96)
            #   (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            # )
            x = self.stride_conv(x)  # (1,96,7,7)

        # ConvNorm(
        #   (conv): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))
        #   (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        q = self.q(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2)  # (1,256,7,7)->(1,8,32,49)->(1,8,49,32)
        # ConvNorm(
        #   (conv): Conv2d(96, 256, kernel_size=(1, 1), stride=(1, 1))
        #   (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        k = self.k(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 2, 3)  # (1,256,7,7)->(1,8,32,49)->(1,8,32,49)
        # ConvNorm(
        #   (conv): Conv2d(96, 1024, kernel_size=(1, 1), stride=(1, 1))
        #   (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        v = self.v(x)  # (1,1024,7,7)
        # ConvNorm(
        #   (conv): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1024)
        #   (bn): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        v_local = self.v_local(v)  # (1,1024,7,7)
        v = v.reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2)  # (1,8,128,49)->(1,8,49,128)

        attn = (q @ k) * self.scale  # (1,8,49,49)
        attn = attn + self.get_attention_biases(x.device)  # (1,8,49,49) + (8,49,49) -> (1,8,49,49)
        # Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
        attn = self.talking_head1(attn)  # (1,8,49,49)
        attn = attn.softmax(dim=-1)  # (1,8,49,49)
        # Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
        attn = self.talking_head2(attn)  # (1,8,49,49)

        x = (attn @ v).transpose(2, 3)  # (1,8,49,128)->(1,8,128,49)
        x = x.reshape(B, self.dh, self.resolution[0], self.resolution[1]) + v_local  # (1,1024,7,7)
        if self.upsample is not None:
            x = self.upsample(x)  # (1,1024,14,14)

        x = self.act(x)
        # ConvNorm(
        #   (conv): Conv2d(1024, 96, kernel_size=(1, 1), stride=(1, 1))
        #   (bn): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        x = self.proj(x)  # (1,96,14,14)
        return x
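
A quick shape check of this module (assuming Attention2d is importable from timm.models.efficientformer_v2 in your timm version; the constructor arguments mirror the annotated values above): with stride=2 the attention runs at 7x7 internally, but the output is upsampled back to the 14x14 input resolution.

import torch
from timm.models.efficientformer_v2 import Attention2d  # path may differ across timm versions

attn = Attention2d(dim=96, key_dim=32, num_heads=8, attn_ratio=4, resolution=14, stride=2)
attn.eval()
with torch.no_grad():
    y = attn(torch.randn(1, 96, 14, 14))
print(y.shape)  # torch.Size([1, 96, 14, 14]) -- resolution restored by the bilinear upsample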

The last change is attention-based downsampling, as shown in Figure 2(f); the implementation is below. Only the downsampling before the last stage uses attention downsampling; the earlier downsampling layers use a plain 3x3 stride-2 convolution. The last stage is built with use_attn=True, in which case the strided-conv downsampling and the attention downsampling are summed to produce the final output.

class Downsample(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=3,
            stride=2,
            padding=1,
            resolution=7,
            use_attn=False,
            act_layer=nn.GELU,
            norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()

        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        padding = to_2tuple(padding)
        norm_layer = norm_layer or nn.Identity()
        self.conv = ConvNorm(
            in_chs,
            out_chs,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            norm_layer=norm_layer,
        )

        if use_attn:
            self.attn = Attention2dDownsample(
                dim=in_chs,  # 96
                out_dim=out_chs,  # 176
                resolution=resolution,  # (14,14)
                act_layer=act_layer,
            )
        else:
            self.attn = None

    def forward(self, x):
        out = self.conv(x)
        if self.attn is not None:
            return self.attn(x) + out
        return out

The attention downsampling itself is implemented as follows. The query is first computed by self.q(x), where self.q = LocalGlobalQuery implements the Pool and Conv paths in front of Q in Figure 2(f): the query is downsampled by pooling and a strided depthwise convolution. In addition, self.v_local injects locality here as well (the `+ v_local` in the forward), which is not drawn in Figure 2(f).

class LocalGlobalQuery(torch.nn.Module):
    def __init__(self, in_dim, out_dim):  # 96,128
        super().__init__()
        self.pool = nn.AvgPool2d(1, 2, 0)
        self.local = nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=2, padding=1, groups=in_dim)
        self.proj = ConvNorm(in_dim, out_dim, 1)

    def forward(self, x):  # (1,96,14,14)
        # Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96)
        local_q = self.local(x)  # (1,96,7,7)
        # AvgPool2d(kernel_size=1, stride=2, padding=0)
        pool_q = self.pool(x)  # (1,96,7,7)
        q = local_q + pool_q  # (1,96,7,7)
        # ConvNorm(
        #   (conv): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1))
        #   (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        q = self.proj(q)  # (1,128,7,7)
        return q


class Attention2dDownsample(torch.nn.Module):
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(
            self,
            dim=384,  # 96
            key_dim=16,
            num_heads=8,
            attn_ratio=4,
            resolution=7,
            out_dim=None,  # 176
            act_layer=nn.GELU,
    ):
        super().__init__()

        self.num_heads = num_heads  # 8
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim  # 16
        self.resolution = to_2tuple(resolution)  # (14,14)
        self.resolution2 = tuple([math.ceil(r / 2) for r in self.resolution])  # (7,7)
        self.N = self.resolution[0] * self.resolution[1]  # 196
        self.N2 = self.resolution2[0] * self.resolution2[1]  # 49

        self.d = int(attn_ratio * key_dim)  # 4*16=64
        self.dh = int(attn_ratio * key_dim) * num_heads  # 64*8=512
        self.attn_ratio = attn_ratio  # 4
        self.out_dim = out_dim or dim  # 176
        kh = self.key_dim * self.num_heads  # 16x8=128

        self.q = LocalGlobalQuery(dim, kh)  # 96,128
        self.k = ConvNorm(dim, kh, 1)  # 96,128
        self.v = ConvNorm(dim, self.dh, 1)  # 96,512
        self.v_local = ConvNorm(self.dh, self.dh, kernel_size=3, stride=2, groups=self.dh)

        self.act = act_layer()
        self.proj = ConvNorm(self.dh, self.out_dim, 1)

        self.attention_biases = nn.Parameter(torch.zeros(num_heads, self.N))  # (8,196)
        k_pos = torch.stack(ndgrid(torch.arange(self.resolution[0]), torch.arange(self.resolution[1]))).flatten(1)  # (2,196)
        q_pos = torch.stack(ndgrid(
            torch.arange(0, self.resolution[0], step=2),
            torch.arange(0, self.resolution[1], step=2)
        )).flatten(1)  # (2,49)
        rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()  # (2,49,196)
        rel_pos = (rel_pos[0] * self.resolution[1]) + rel_pos[1]  # (49,196)
        self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)  # (49,196)
        self.attention_bias_cache = {}  # per-device attention_biases cache (data-parallel compat)

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, x):  # (1,96,14,14)
        B, C, H, W = x.shape

        q = self.q(x).reshape(B, self.num_heads, -1, self.N2).permute(0, 1, 3, 2)  # (1,128,7,7)->(1,8,16,49)->(1,8,49,16)
        # ConvNorm(
        #   (conv): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1))
        #   (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        k = self.k(x).reshape(B, self.num_heads, -1, self.N).permute(0, 1, 2, 3)  # (1,128,14,14)->(1,8,16,196)->(1,8,16,196)
        # ConvNorm(
        #   (conv): Conv2d(96, 512, kernel_size=(1, 1), stride=(1, 1))
        #   (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        v = self.v(x)  # (1,512,14,14)
        # ConvNorm(
        #   (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        #   (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        v_local = self.v_local(v)  # (1,512,7,7)
        v = v.reshape(B, self.num_heads, -1, self.N).permute(0, 1, 3, 2)  # (1,512,14,14)->(1,8,64,196)->(1,8,196,64)

        attn = (q @ k) * self.scale  # (1,8,49,196)
        attn = attn + self.get_attention_biases(x.device)  # (1,8,49,196) + (8,49,196) -> (1,8,49,196)
        attn = attn.softmax(dim=-1)  # (1,8,49,196)
        x = (attn @ v).transpose(2, 3)  # (1,8,49,64)->(1,8,64,49)
        x = x.reshape(B, self.dh, self.resolution2[0], self.resolution2[1]) + v_local  # (1,512,7,7)
        x = self.act(x)
        # ConvNorm(
        #   (conv): Conv2d(512, 176, kernel_size=(1, 1), stride=(1, 1))
        #   (bn): BatchNorm2d(176, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        # )
        x = self.proj(x)  # (1,176,7,7)
        return x
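
Finally, a shape check of the attention-downsampling path in isolation (same caveat about the import path): a 96-channel 14x14 input comes out as 176 channels at 7x7, matching the strided-conv branch it is summed with in Downsample.

import torch
from timm.models.efficientformer_v2 import Attention2dDownsample  # path may differ across timm versions

attn_ds = Attention2dDownsample(dim=96, out_dim=176, resolution=14)
attn_ds.eval()
with torch.no_grad():
    y = attn_ds(torch.randn(1, 96, 14, 14))
print(y.shape)  # torch.Size([1, 176, 7, 7])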
