【MogaNet】Efficient Multi-order Gated Aggregation Network: Reading Notes

Efficient Multi-order Gated Aggregation Network

Overview

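Today I am sharing MogaNet, an arXiv paper that proposes a new backbone network. The authors are from Westlake University and Zhejiang University. MogaNet examines the representational capacity of DNNs through the lens of interaction complexities. It is a pure convolutional network: a multi-order feature aggregation module, loosely modeled on the human visual system, gives it both local perception and global context aggregation. Two purpose-built modules, the SMixer and the CMixer, aggregate contextual information along the spatial and channel dimensions respectively. The result is a good trade-off between computational cost and accuracy, with competitive results on classification, detection, and segmentation.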

Background

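A convolutional layer encodes neighborhood correlations in the input image through locally dense connections and a translation-invariance constraint. Stacking convolutional layers progressively enlarges the receptive field, letting the network adaptively recognize semantic features. However, the features CNNs learn tend to focus on local texture and lack global context. Macro-level architecture designs (ASPP, FPN) and added global context aggregation modules (SE-Block, GC-Block) partially address this problem.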
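To make the idea of global context aggregation concrete, below is a minimal sketch of a Squeeze-and-Excitation block in its common form: global average pooling, a two-layer bottleneck, and sigmoid channel gating. The class name `SEBlock` and the reduction ratio of 16 are my own choices for illustration, and 1×1 convolutions on the pooled tensor stand in for the fully connected layers (they are equivalent at 1×1 resolution). This block is not part of MogaNet.

```python
import torch
import torch.nn as nn


class SEBlock(nn.Module):
    """Squeeze-and-Excitation sketch: reweight channels using global context."""

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        hidden = max(channels // reduction, 1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1),  # bottleneck: C -> C/r
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1),  # expand back: C/r -> C
            nn.Sigmoid(),                                # per-channel weight in (0, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Squeeze: [B, C, H, W] -> [B, C, 1, 1], one global summary per channel
        context = x.mean(dim=(2, 3), keepdim=True)
        # Excitation: map the global summary to channel weights and rescale x
        return x * self.fc(context)


x = torch.randn(2, 64, 32, 32)
y = SEBlock(64)(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```

Because every output pixel is scaled by a weight computed from the whole feature map, each location receives global information even though no large kernel is used.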

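Self-attention-based Transformers focus more on global dependencies in the image and have surpassed CNNs in many domains. However, the quadratic complexity of self-attention limits their computational efficiency and their use in downstream dense prediction tasks. Lacking the inductive biases of CNNs, Transformers are also weaker at learning local neighborhood relations in images. Recent work has reintroduced CNN-style pyramid structures (as in FPN) into ViTs.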
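A rough cost comparison illustrates the quadratic-complexity point. The helpers below are hypothetical and count only the two N×N matrix products of self-attention (ignoring the Q/K/V and output projections) against a k×k depth-wise convolution. Doubling the height and width multiplies the attention term by 16 but the convolution term by only 4.

```python
def attention_macs(h: int, w: int, c: int) -> int:
    """Approximate MACs of the two N x N products in self-attention."""
    n = h * w              # number of tokens
    return 2 * n * n * c   # Q @ K^T and softmax(Q K^T) @ V, each N * N * C


def dwconv_macs(h: int, w: int, c: int, k: int = 7) -> int:
    """MACs of a k x k depth-wise convolution (stride 1, 'same' padding)."""
    return h * w * c * k * k   # linear in the number of pixels


for size in (14, 28, 56):
    print(size, attention_macs(size, size, 64), dwconv_macs(size, size, 64))
```

This is why dense prediction at high resolution is expensive for plain self-attention, while convolutions scale linearly with the number of pixels.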

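The human visual system not only extracts local features but also aggregates global perceptual features, and it is more compact and efficient than deep networks. Most modern DNNs tend to encode interactions of extremely low or extremely high complexity rather than the most informative intermediate-order interactions.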
lens of interaction complexities
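As the figure shows, humans can recover almost all of an image's information when about 50% of it is occluded, whereas DNNs draw most of their information from interactions involving less than 10% or more than 90% of the image.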
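As far as I understand, the interaction complexity here follows the multi-order interaction of Deng et al. (2022): for input variables i and j and a context S of m other variables, Δf(i, j, S) = f(S∪{i, j}) − f(S∪{i}) − f(S∪{j}) + f(S), and the order-m interaction I^(m)(i, j) is the average of Δf over contexts with |S| = m. The ratio m/n is the fraction of the input that is kept. The sketch below estimates I^(m) by random sampling on toy functions of a binary keep-mask; the function names and the toy models are hypothetical, and a real evaluation would mask image patches and use the network output as f.

```python
import random
from typing import Callable, List, Set

Model = Callable[[List[int]], float]  # maps a 0/1 keep-mask to a scalar output


def _value(f: Model, n: int, kept: Set[int]) -> float:
    """Evaluate f with only the variables in `kept` present (mask = 1)."""
    mask = [0] * n
    for p in kept:
        mask[p] = 1
    return f(mask)


def delta_f(f: Model, n: int, i: int, j: int, context: Set[int]) -> float:
    """Interaction gain of i and j given a context S that excludes i and j."""
    return (_value(f, n, context | {i, j}) - _value(f, n, context | {i})
            - _value(f, n, context | {j}) + _value(f, n, context))


def interaction_of_order(f: Model, n: int, i: int, j: int, m: int,
                         num_samples: int = 200, seed: int = 0) -> float:
    """Monte Carlo estimate of I^(m)(i, j) = E_{|S| = m}[delta_f(i, j, S)]."""
    rng = random.Random(seed)
    others = [p for p in range(n) if p not in (i, j)]
    total = 0.0
    for _ in range(num_samples):
        context = set(rng.sample(others, m))  # random context of size m
        total += delta_f(f, n, i, j, context)
    return total / num_samples


def pairwise(mask: List[int]) -> float:
    """Toy model: variables 0 and 1 interact regardless of the context."""
    return float(mask[0] * mask[1])


def needs_all(mask: List[int]) -> float:
    """Toy model: fires only when every variable is present."""
    return float(all(mask))


n = 6
print([interaction_of_order(pairwise, n, 0, 1, m) for m in range(n - 1)])
# [1.0, 1.0, 1.0, 1.0, 1.0]
print([interaction_of_order(needs_all, n, 0, 1, m) for m in range(n - 1)])
# [0.0, 0.0, 0.0, 0.0, 1.0]
```

The first toy model carries its interaction at every order, while the second carries it only at the highest order (it needs the whole input), the extreme that DNNs are said to over-emphasize.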

Network Architecture

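The macro architecture proposed by MogaNet: a stem adjusts the input dimensions, then in each block the SMixer extracts spatial features and the CMixer mixes channels.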
MogaBlock

SMixer
CMixer

Overall network architecture
MogaNet
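The shape bookkeeping of a hierarchical backbone like this can be checked with a few lines of arithmetic. The sketch assumes a four-stage layout in which the stem downsamples by 4 (two stride-2 3×3 convolutions) and each later stage starts with one stride-2 3×3 convolution; the helper names and the channel widths are placeholders for illustration, not an official MogaNet configuration.

```python
from typing import List, Tuple


def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1


def stage_shapes(height: int, width: int, dims: List[int]) -> List[Tuple[int, int, int]]:
    """(C, H, W) of each stage output, assuming a stride-4 stem and stride-2 stages."""
    h, w = height, width
    shapes = []
    for stage, c in enumerate(dims):
        n_convs = 2 if stage == 0 else 1   # stem: two stride-2 convs; later stages: one
        for _ in range(n_convs):
            h, w = conv_out(h), conv_out(w)
        shapes.append((c, h, w))           # MogaBlocks are residual and keep the shape
    return shapes


print(stage_shapes(224, 224, [32, 64, 128, 256]))
# [(32, 56, 56), (64, 28, 28), (128, 14, 14), (256, 7, 7)]
```

These four resolutions (1/4 to 1/32 of the input) are what a pyramid-style detection or segmentation head would consume.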

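Ablation Studies

Comparison of different context aggregation methods
Effect of the MLP channel expansion ratio on parameters and accuracy
Effect of different modules on accuracy and parameters
Effect of different normalization layers
Effect of different activation functions

Effect of different channel split ratios on accuracy
Effect of different modules on accuracy and on interaction strength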
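The parameter side of the expansion-ratio ablation can be worked out by hand from the `ChannelAggregationFFN` code later in this post: the two 1×1 convolutions grow linearly with the hidden width r·C, while the depth-wise and decomposition parts are comparatively small. The helper below is a hypothetical counting function that assumes every convolution has a bias, as in that code.

```python
def ca_ffn_params(embed_dims: int, ffn_ratio: float, kernel_size: int = 3) -> int:
    """Parameter count of ChannelAggregationFFN (all convs with bias)."""
    c = embed_dims
    hidden = int(c * ffn_ratio)
    fc1 = c * hidden + hidden                              # 1x1 conv: C -> hidden
    dwconv = hidden * kernel_size * kernel_size + hidden   # depth-wise k x k conv
    fc2 = hidden * c + c                                   # 1x1 conv: hidden -> C
    decompose = hidden * 1 + 1                             # 1x1 conv: hidden -> 1
    sigma = hidden                                         # ElementScale, one per channel
    return fc1 + dwconv + fc2 + decompose + sigma


for r in (2, 4, 8):
    print(r, ca_ffn_params(64, r))
```

For C = 64 this gives 18,113 parameters at r = 2 and 36,161 at r = 4, so the count roughly doubles with the ratio.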

Feature Visualization

CAM visualization

Experimental Results

Classification with large models

Official Code Repository

https://github.com/Westlake-AI/openmixup

MogaNet

ElementScale

```python
import torch
import torch.nn as nn


class ElementScale(nn.Module):
    """A learnable element-wise (per-channel) scaler."""

    def __init__(self, embed_dims, init_value=0., requires_grad=True):
        super(ElementScale, self).__init__()
        # one scale per channel, broadcast over the batch and spatial dims
        self.scale = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)),
            requires_grad=requires_grad
        )

    def forward(self, x):
        # x: [B, C, H, W]
        return x * self.scale
```

Moga Block

```python
import torch
import torch.nn as nn
# Assumes mmcv 1.x.
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks import DropPath
from mmcv.runner import BaseModule

# MultiOrderGatedAggregation (SMixer) and ChannelAggregationFFN (CMixer)
# are defined further below in this post.


class MogaBlock(BaseModule):
    """A block of MogaNet.

    Args:
        embed_dims (int): Number of input channels.
        ffn_ratio (float): The expansion ratio of feedforward network hidden
            layer channels. Defaults to 4.
        drop_rate (float): Dropout rate after embedding. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        act_cfg (dict, optional): The activation config for projections and FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN', eps=1e-5)``.
        init_value (float): Init value for Layer Scale. Defaults to 1e-5.
        attn_dw_dilation (list): Dilations of three DWConv layers.
        attn_channel_split (list): The relative ratio of split channels.
        attn_act_cfg (dict): The activation config for the gating branch.
            Default: dict(type='SiLU').
        attn_force_fp32 (bool): Whether to compute the gating in fp32.
            Defaults to False.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 ffn_ratio=4.,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='BN', eps=1e-5),
                 init_value=1e-5,
                 attn_dw_dilation=[1, 2, 3],
                 attn_channel_split=[1, 3, 4],
                 attn_act_cfg=dict(type='SiLU'),
                 attn_force_fp32=False,
                 init_cfg=None):
        super(MogaBlock, self).__init__(init_cfg=init_cfg)
        self.out_channels = embed_dims

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)[1]

        # spatial attention (SMixer)
        self.attn = MultiOrderGatedAggregation(
            embed_dims,
            attn_dw_dilation=attn_dw_dilation,
            attn_channel_split=attn_channel_split,
            attn_act_cfg=attn_act_cfg,
            attn_force_fp32=attn_force_fp32,
        )
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)[1]

        # channel MLP (CMixer)
        mlp_hidden_dim = int(embed_dims * ffn_ratio)
        self.mlp = ChannelAggregationFFN(  # DWConv + Channel Aggregation FFN
            embed_dims=embed_dims,
            feedforward_channels=mlp_hidden_dim,
            act_cfg=act_cfg,
            ffn_drop=drop_rate,
        )

        # init layer scale: a small init_value keeps each block close to
        # identity at the start of training
        self.layer_scale_1 = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)

    def forward(self, x):
        # spatial: x = x + DropPath(LS1 * SMixer(Norm1(x)))
        identity = x
        x = self.layer_scale_1 * self.attn(self.norm1(x))
        x = identity + self.drop_path(x)
        # channel: x = x + DropPath(LS2 * CMixer(Norm2(x)))
        identity = x
        x = self.layer_scale_2 * self.mlp(self.norm2(x))
        x = identity + self.drop_path(x)
        return x
```

SMixer

```python
import torch
import torch.nn.functional as F
# Assumes mmcv 1.x.
from mmcv.cnn import Conv2d
from mmcv.runner import BaseModule, force_fp32

# `custom_build_activation_layer` is a helper from the openmixup repository
# (not shown here); ElementScale and MultiOrderDWConv are defined in this post.


class MultiOrderGatedAggregation(BaseModule):
    """Spatial Block with Multi-order Gated Aggregation.

    Args:
        embed_dims (int): Number of input channels.
        attn_dw_dilation (list): Dilations of three DWConv layers.
        attn_channel_split (list): The relative ratio of split channels.
        attn_act_cfg (dict, optional): The activation config for the gating
            and value branches. Default: dict(type='SiLU').
        attn_force_fp32 (bool): Whether to compute the gating in fp32.
            Defaults to False.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 attn_dw_dilation=[1, 2, 3],
                 attn_channel_split=[1, 3, 4],
                 attn_act_cfg=dict(type="SiLU"),
                 attn_force_fp32=False,
                 init_cfg=None):
        super(MultiOrderGatedAggregation, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.attn_force_fp32 = attn_force_fp32
        self.proj_1 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.gate = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.value = MultiOrderDWConv(
            embed_dims,
            dw_dilation=attn_dw_dilation,
            channel_split=attn_channel_split,
        )
        self.proj_2 = Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

        # activation for gating and value
        self.act_value = custom_build_activation_layer(attn_act_cfg)
        self.act_gate = custom_build_activation_layer(attn_act_cfg)

        # decompose
        self.sigma = ElementScale(
            embed_dims, init_value=1e-5, requires_grad=True)

    def feat_decompose(self, x):
        x = self.proj_1(x)
        # x_d: [B, C, H, W] -> [B, C, 1, 1]
        x_d = F.adaptive_avg_pool2d(x, output_size=1)
        # re-weight the component that differs from the global average
        x = x + self.sigma(x - x_d)
        x = self.act_value(x)
        return x

    @force_fp32()
    def forward_gating(self, g, v):
        # g and v arrive un-activated; the activation is applied here once
        g = g.to(torch.float32)
        v = v.to(torch.float32)
        return self.proj_2(self.act_gate(g) * self.act_gate(v))

    def forward(self, x):
        shortcut = x.clone()
        # proj 1x1
        x = self.feat_decompose(x)
        # gating and value branch
        g = self.gate(x)
        v = self.value(x)
        # aggregation: act(gate) * act(value), then a 1x1 projection
        if not self.attn_force_fp32:
            x = self.proj_2(self.act_gate(g) * self.act_gate(v))
        else:
            # pass raw g and v so the activation is not applied twice
            x = self.forward_gating(g, v)
        x = x + shortcut
        return x
```
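The `feat_decompose` step above computes x + σ ⊙ (x − GAP(x)): it subtracts the per-channel global average and re-weights the remainder with a learnable per-channel σ initialized near zero. The standalone sketch below uses plain PyTorch, with a hypothetical helper `spatial_decompose` and a fixed σ tensor in place of `ElementScale`; it leaves out `proj_1` and the activation. It checks two properties: σ = 0 leaves the input unchanged, and because x − GAP(x) has zero spatial mean, the per-channel mean is preserved for any σ.

```python
import torch
import torch.nn.functional as F


def spatial_decompose(x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """x + sigma * (x - GAP(x)); x: [B, C, H, W], sigma: [1, C, 1, 1]."""
    x_d = F.adaptive_avg_pool2d(x, output_size=1)  # [B, C, 1, 1] global average
    return x + sigma * (x - x_d)                    # scale the non-average part


torch.manual_seed(0)
x = torch.randn(2, 8, 16, 16)
y0 = spatial_decompose(x, torch.zeros(1, 8, 1, 1))  # sigma = 0
y1 = spatial_decompose(x, torch.ones(1, 8, 1, 1))   # sigma = 1

print(torch.allclose(y0, x))                                            # identity check
print(torch.allclose(y1.mean(dim=(2, 3)), x.mean(dim=(2, 3)), atol=1e-5))  # mean check
```

In other words, σ only controls how strongly local deviations from the global context are emphasized or suppressed, without shifting the global component itself.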

MultiOrderDWConv

```python
import torch
# Assumes mmcv 1.x.
from mmcv.cnn import Conv2d
from mmcv.runner import BaseModule


class MultiOrderDWConv(BaseModule):
    """Multi-order Features with Dilated DWConv Kernel.

    Args:
        embed_dims (int): Number of input channels.
        dw_dilation (list): Dilations of three DWConv layers.
        channel_split (list): The relative ratio of three split channels.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 dw_dilation=[1, 2, 3],
                 channel_split=[1, 3, 4],
                 init_cfg=None):
        super(MultiOrderDWConv, self).__init__(init_cfg=init_cfg)

        self.split_ratio = [i / sum(channel_split) for i in channel_split]
        self.embed_dims_1 = int(self.split_ratio[1] * embed_dims)
        self.embed_dims_2 = int(self.split_ratio[2] * embed_dims)
        self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2
        self.embed_dims = embed_dims
        assert len(dw_dilation) == len(channel_split) == 3
        assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3
        assert embed_dims % sum(channel_split) == 0

        # basic DW conv: 5x5, applied to all channels
        self.DW_conv0 = Conv2d(
            in_channels=self.embed_dims,
            out_channels=self.embed_dims,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[0]) // 2,
            groups=self.embed_dims,
            stride=1, dilation=dw_dilation[0],  # 1 by default
        )
        # DW conv 1: dilated 5x5 on the middle embed_dims_1 channels
        self.DW_conv1 = Conv2d(
            in_channels=self.embed_dims_1,
            out_channels=self.embed_dims_1,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[1]) // 2,
            groups=self.embed_dims_1,
            stride=1, dilation=dw_dilation[1],  # 2 by default
        )
        # DW conv 2: dilated 7x7 on the last embed_dims_2 channels
        self.DW_conv2 = Conv2d(
            in_channels=self.embed_dims_2,
            out_channels=self.embed_dims_2,
            kernel_size=7,
            padding=(1 + 6 * dw_dilation[2]) // 2,
            groups=self.embed_dims_2,
            stride=1, dilation=dw_dilation[2],  # 3 by default
        )
        # a channel convolution
        self.PW_conv = Conv2d(  # point-wise convolution
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=1)

    def forward(self, x):
        x_0 = self.DW_conv0(x)
        # middle channels: short-range features -> medium-range (dilated 5x5)
        x_1 = self.DW_conv1(
            x_0[:, self.embed_dims_0: self.embed_dims_0+self.embed_dims_1, ...])
        # last channels: short-range features -> long-range (dilated 7x7)
        x_2 = self.DW_conv2(
            x_0[:, self.embed_dims-self.embed_dims_2:, ...])
        # first embed_dims_0 channels keep only the short-range 5x5 features
        x = torch.cat([
            x_0[:, :self.embed_dims_0, ...], x_1, x_2], dim=1)
        # mix the three orders across channels
        x = self.PW_conv(x)
        return x
```
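With the default arguments, the channel split and the receptive field of each branch follow directly from the constructor above: `DW_conv0` (5×5, dilation 1) runs on all channels, then 3/8 of the channels go through `DW_conv1` (5×5, dilation 2) and 4/8 through `DW_conv2` (7×7, dilation 3). For stacked stride-1 convolutions the receptive field is 1 + Σ d·(k − 1). The helper names below are hypothetical; only the arithmetic mirrors the code.

```python
from typing import List, Tuple


def split_channels(embed_dims: int, channel_split: List[int]) -> Tuple[int, int, int]:
    """Reproduce the channel split computed in MultiOrderDWConv.__init__."""
    ratio = [i / sum(channel_split) for i in channel_split]
    dims_1 = int(ratio[1] * embed_dims)
    dims_2 = int(ratio[2] * embed_dims)
    dims_0 = embed_dims - dims_1 - dims_2   # remainder keeps only DW_conv0 features
    return dims_0, dims_1, dims_2


def receptive_field(layers: List[Tuple[int, int]]) -> int:
    """Receptive field of stacked stride-1 convs given (kernel, dilation) pairs."""
    return 1 + sum(d * (k - 1) for k, d in layers)


def same_padding(kernel: int, dilation: int) -> int:
    """Padding that keeps H and W unchanged, matching the constructor's formula."""
    return (1 + (kernel - 1) * dilation) // 2


print(split_channels(64, [1, 3, 4]))        # (8, 24, 32)
print(receptive_field([(5, 1)]))            # branch 0: 5
print(receptive_field([(5, 1), (5, 2)]))    # branch 1: 13
print(receptive_field([(5, 1), (7, 3)]))    # branch 2: 23
```

So one cheap depth-wise stack yields three context sizes (5, 13, and 23 pixels), with most channels assigned to the larger contexts.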

Channel Aggregation

```python
import torch.nn as nn
# Assumes mmcv 1.x.
from mmcv.cnn import Conv2d, build_activation_layer
from mmcv.runner import BaseModule

# ElementScale is defined earlier in this post.


class ChannelAggregationFFN(BaseModule):
    """An implementation of FFN with Channel Aggregation.

    Args:
        embed_dims (int): The feature dimension (number of input channels).
        feedforward_channels (int): The hidden dimension of FFNs.
        kernel_size (int): The kernel size of the depth-wise convolution.
            Defaults to 3.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='GELU').
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 feedforward_channels,
                 kernel_size=3,
                 act_cfg=dict(type='GELU'),
                 ffn_drop=0.,
                 init_cfg=None):
        super(ChannelAggregationFFN, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.act_cfg = act_cfg

        self.fc1 = Conv2d(
            in_channels=embed_dims,
            out_channels=self.feedforward_channels,
            kernel_size=1)
        self.dwconv = Conv2d(
            in_channels=self.feedforward_channels,
            out_channels=self.feedforward_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            bias=True,
            groups=self.feedforward_channels)
        self.act = build_activation_layer(act_cfg)
        self.fc2 = Conv2d(
            in_channels=feedforward_channels,
            out_channels=embed_dims,
            kernel_size=1)
        self.drop = nn.Dropout(ffn_drop)

        self.decompose = Conv2d(
            in_channels=self.feedforward_channels,  # C -> 1
            out_channels=1, kernel_size=1,
        )
        self.sigma = ElementScale(
            self.feedforward_channels, init_value=1e-5, requires_grad=True)
        self.decompose_act = build_activation_layer(act_cfg)

    def feat_decompose(self, x):
        # x_d: [B, C, H, W] -> [B, 1, H, W], broadcast back over channels
        x = x + self.sigma(x - self.decompose_act(self.decompose(x)))
        return x

    def forward(self, x):
        # proj 1
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act(x)
        x = self.drop(x)
        # proj 2
        x = self.feat_decompose(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
```
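The channel-side `feat_decompose` works analogously: it computes x + σ ⊙ (x − GELU(W_d x)), where W_d is a 1×1 convolution that collapses all channels into one map, so at every pixel each channel is re-weighted relative to a shared cross-channel summary. The plain-PyTorch sketch below uses a hypothetical class name, `ChannelDecompose`, with `nn.Conv2d` and `nn.GELU` standing in for the mmcv builders.

```python
import torch
import torch.nn as nn


class ChannelDecompose(nn.Module):
    """x + sigma * (x - act(W_d x)), where W_d is a 1x1 conv from C to 1."""

    def __init__(self, channels: int, init_value: float = 1e-5):
        super().__init__()
        self.decompose = nn.Conv2d(channels, 1, kernel_size=1)  # [B, C, H, W] -> [B, 1, H, W]
        self.act = nn.GELU()
        self.sigma = nn.Parameter(init_value * torch.ones(1, channels, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_d = self.act(self.decompose(x))     # one shared summary map per pixel
        return x + self.sigma * (x - x_d)     # x_d broadcasts over the C channels


m = ChannelDecompose(256)
x = torch.randn(2, 256, 14, 14)
print(m(x).shape)  # torch.Size([2, 256, 14, 14])
```

With σ initialized at 1e-5 the module starts out as a near-identity map, which mirrors the small `init_value` used for the layer scales in `MogaBlock`.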