MetaFormer(CVPR 2022,Sea)

paper:MetaFormer Is Actually What You Need for Vision

official implementation:https://github.com/sail-sg/poolformer

third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/metaformer.py

背景

这篇文章讨论了Transformer在计算机视觉任务中的应用,并提出了一种新的框架MetaFormer。研究者注意到,尽管Transformer的attention机制通常被认为是其优越性能的关键,但最近的研究表明,使用其他简单的token mixer(如spatial MLP)也能取得良好效果。因此,文章探讨了Transformer的通用架构MetaFormer,而不仅仅是特定的token mixer。

出发点

文章的出发点是验证一个假设,即Transformer模型的成功主要归因于其通用架构MetaFormer,而不是特定的token mixer(如attention)。为验证这一假设,研究者将Transformer中的attention模块替换为一个简单的空间池化操作,形成一个新的模型——PoolFormer。

创新点

  1. MetaFormer框架:提出并验证了MetaFormer这一通用架构,证明其比特定的token mixer(如attention)更为关键。
  2. 简单的token mixer:通过使用简单的空间池化操作,研究者验证了即使没有复杂的token mixer,MetaFormer架构也能实现优越的性能。
  3. 性能验证:实验结果显示,PoolFormer在ImageNet-1K分类任务中超过了精调的Vision Transformer和MLP-like基线模型,参数和计算量都显著减少。

总体而言,文章通过引入和验证MetaFormer这一新的通用架构,对Transformer和MLP-like模型在计算机视觉任务中的成功原因进行了深入探讨,并为未来的研究提供了新的方向。

方法介绍

MetaFormer

MetaFormer的整体结构如图1所示,输入 \(I\) 首先通过input embedding的处理,比如ViTs中的patch embedding

其中 \(X\in\mathbb{R}^{N\times C}\) 表示序列长度为 \(N\) 嵌入维度为 \(C\) 的embedding token,然后输入到重复的MetaFormer block中,每个block包含两个residual sub-blocks。具体来说,第一个sub-block包含一个token mixer用来进行token之间的信息交互,表示如下

其中 \(Norm(\cdot)\) 表示规范化比如Layer Normalization或Batch Normalization。\(TokenMixer(\cdot)\) 表示一个用于混合token信息的模块,比如Transformer中的各种attention机制或MLP类模型中的spatial MLP。注意token mixer的主要作用是用于混合token信息,尽管有些token mixer同时也混合通道信息例如attention。

第二个sub-block主要由一个具有非线性激活的两层MLP组成

其中 \(W_1\in \mathbb{R}^{C\times rC}\) 和 \(W_2\in \mathbb{R}^{rC\times C}\) 是MLP的可学习参数,\(r\) 是expansion ratio,\(\sigma(\cdot)\) 是非线性激活函数例如ReLU或GELU。

MetaFormer描述了一种通用的体系结构,通过指定token mixer的具体设计,可以立即获得不同的模型。例如如果将token mixer指定为attention或spatial MLP,就可以得到Transformer或MLP类的模型。

PoolFormer

作者认为Transformer或MLP类模型的成功主要是因为MetaFormer这种通用的架构,为了证明这点,作者故意使用了一个非常简单的操作即池化来作为token mixer,它没有可学习的参数,只是通过平均聚合每个位置周围token的特征来得到该位置的输出,对于输入 \(T^{C\times H\times W}\),池化可以表示为

其中 \(K\) 是pooling size,注意由于MetaFormer block本身有一个residual connection,这里减去了自身。

PoolFormer的整体结构如图2所示,这里和CNN一样采用了一种金字塔结构。具体来说一共有4个stage,每个stage的分辨率减半,一共有两组embedding size,对于小尺寸模型四个stage的embedding size分别为64,128,320,512,对于大尺寸模型为96,192,384,768。假设模型一共有 \(L\) 个block,四个stage的block数量分别为L/6, L/6, L/2, L/6。不同PoolFormer变种的具体配置如下

此外,作者对Layer Normalization进行了修改,和原始的LN沿通道维度计算均值和方差不同,这里修改后的LN即Modified Layer Normalization(MLN)沿通道和token维度计算均值和方差。

实验结果

在ImageNet上和其它结构的模型的性能对比如下表所示,可以看到,PoolFormer尽管结构简单也取得了极具竞争力或者更好的结果。

在COCO目标检测任务上,效果好好于ResNet

在ADE20K语义分割任务上效果也很好。

此外作者还进行了各个componets的消融实验,结果如下。首先是token mixer,可以看到即使换成identity mapping精度也达到了74.3%,表明更多起作用的是MetaFormer的整体结构,而不是token mixer的具体形式。本文提出的MLN比原始的LN效果要更好。如果将residual connection和channel mlp去掉,网络直接不收敛了,表明MetaFormer整体结构的重要性。

此外基于pooling的token mixer可以处理更长的序列,而基于attention的token mixer更适合捕获全局信息,因此作者又试验了一种混合结构,即早期stage采用pooling token mixer,后期stage采用attention token mixer,如最后一栏所示,可以看到这种混合结构是有助于提高模型性能的。

代码解析

这里以timm中的实现为例,输入大小为(1, 3, 224, 224),模型选择"poolformer_s24",模型配置如下

def poolformer_s24(pretrained=False, **kwargs) -> MetaFormer:
    model_kwargs = dict(
        depths=[4, 4, 12, 4],
        dims=[64, 128, 320, 512],
        downsample_norm=None,
        mlp_act=nn.GELU,
        mlp_bias=True,
        norm_layers=GroupNorm1,
        layer_scale_init_values=1e-5,
        res_scale_init_values=None,
        use_mlp_head=False,
        **kwargs)
    return _create_metaformer('poolformer_s24', pretrained=pretrained, **model_kwargs)

MetaFormerBlock的代码如下

class MetaFormerBlock(nn.Module):
    """
    Implementation of one MetaFormer block.
    """

    def __init__(
            self,
            dim,
            token_mixer=Pooling,
            mlp_act=StarReLU,
            mlp_bias=False,
            norm_layer=LayerNorm2d,
            proj_drop=0.,
            drop_path=0.,
            use_nchw=True,
            layer_scale_init_value=None,
            res_scale_init_value=None,
            **kwargs
    ):
        super().__init__()
        ls_layer = partial(Scale, dim=dim, init_value=layer_scale_init_value, use_nchw=use_nchw)
        rs_layer = partial(Scale, dim=dim, init_value=res_scale_init_value, use_nchw=use_nchw)

        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim=dim, proj_drop=proj_drop, **kwargs)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale1 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
        self.res_scale1 = rs_layer() if res_scale_init_value is not None else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            dim,
            int(4 * dim),
            act_layer=mlp_act,
            bias=mlp_bias,
            drop=proj_drop,
            use_conv=use_nchw,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale2 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
        self.res_scale2 = rs_layer() if res_scale_init_value is not None else nn.Identity()

    def forward(self, x):
        x = self.res_scale1(x) + \
            self.layer_scale1(
                self.drop_path1(
                    self.token_mixer(self.norm1(x))
                )
            )
        x = self.res_scale2(x) + \
            self.layer_scale2(
                self.drop_path2(
                    self.mlp(self.norm2(x))
                )
            )
        return x

其中self.token_mixer为平均池化并且减去输入,代码如下

class Pooling(nn.Module):
    """
    Implementation of pooling for PoolFormer: https://arxiv.org/abs/2111.11418
    """

    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        y = self.pool(x)
        return y - x

Modified Layer Normalization可以通过将pytorch自带的GroupNorm的group number设置为1来实现,代码如下

class GroupNorm1(nn.GroupNorm):
    """ Group Normalization with 1 group.
    Input: tensor in shape [B, C, *]
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(1, num_channels, **kwargs)
        self.fast_norm = is_fast_norm()  # can't script unless we have these flags here (no globals)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.fast_norm:
            return fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值