RepViT:从ViT的角度重新审视mobile CNN

RepViT: Revisiting Mobile CNN From ViT Perspective

摘要

近年来,与轻量级卷积神经网络(cnn)相比,轻量级视觉变压器(ViTs)在资源受限的移动设备上表现出了更高的性能和更低的延迟。这种改进通常归功于多头自注意模块,它使模型能够学习全局表示。然而,轻量级vit和轻量级cnn之间的架构差异还没有得到充分的研究。在这项研究中,我们重新审视了轻量级cnn的高效设计,并强调了它们在移动设备上的潜力。通过集成轻量级vit的高效架构选择,我们逐步增强了标准轻量级CNN的移动友好性,特别是MobileNetV3。这就产生了一个新的纯轻量级cnn家族,即RepViT。大量的实验表明,RepViT优于现有的轻型vit,并在各种视觉任务中表现出良好的延迟。在ImageNet上,RepViT在iPhone 12上以近1ms的延迟实现了超过80%的top-1精度,据我们所知,这是轻量级模型的第一次。
代码地址
在这里插入图片描述

本文方法

在这里插入图片描述
在这里插入图片描述
图4。(a)表示MobileNetV3的块,具有可选的SE。在(b)中,我们通过重新定位SE,采用结构重参数化来分离token mixer和channel mixer。©涉及在推理阶段将多分支拓扑整合为单个分支。
原始的MobileNetV3块由1x1的扩展卷积,然后是深度卷积和1x1的投影层组成。剩余连接连接输入和输出。此外,挤压和激励模块可以任选地放置在扩展中的深度滤波器之后。从直观上看,1x1展开卷积和1x1投影层实现了通道间的交互,而深度卷积则实现了空间信息的融合.
前者和后者分别对应于通道混频器和令牌混频器。令牌混频器和通道混频器现在在MobileNetV3块中耦合在一起。因此,如图4 (b)所示,我们将深度卷积向上移动以拆分它们。同时,我们采用结构重参数化,在训练时为深度滤波器引入多分支拓扑,以提高性能。挤压和激励模块也被上移到深度滤波器之后,因为它依赖于空间信息交互。
因此,我们成功地分离了MobileNetV3块中的令牌混频器和通道混频器。此外,在推理过程中,如图4 ©所示,令牌混合器的多分支拓扑被合并为单个深度卷积。
在这里插入图片描述
图5。(a)为MobileNetV3-L中的原始主干,为简单起见,非线性部分略去。我们使用早期卷积作为(b)中的干
在这里插入图片描述
图6。(a)为MobileNetV3-L区块的原始下采样层。采用RepViT块设计后变为(b)。在©中,分别通过分别使用深度卷积和1x1卷积来调制特征图分辨率和通道维度。通过在前面合并一个RepViT块和在后面合并一个FFN,从而加深了所得的下采样层,增强了其整体架构。为简单起见,省略了非线性

在这里插入图片描述

代码

class RepViTBlock(nn.Module):
    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
        super(RepViTBlock, self).__init__()
        assert stride in [1, 2]

        self.identity = stride == 1 and inp == oup
        assert(hidden_dim == 2 * inp)

        if stride == 2:
            self.token_mixer = nn.Sequential(
                Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
                Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
            )
            self.channel_mixer = Residual(nn.Sequential(
                    # pw
                    Conv2d_BN(oup, 2 * oup, 1, 1, 0),
                    nn.GELU() if use_hs else nn.GELU(),
                    # pw-linear
                    Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
                ))
        else:
            assert(self.identity)
            self.token_mixer = nn.Sequential(
                RepVGGDW(inp),
                SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
            )
            self.channel_mixer = Residual(nn.Sequential(
                    # pw
                    Conv2d_BN(inp, hidden_dim, 1, 1, 0),
                    nn.GELU() if use_hs else nn.GELU(),
                    # pw-linear
                    Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
                ))

    def forward(self, x):
        return self.channel_mixer(self.token_mixer(x))

class RepVGGDW(torch.nn.Module):
    def __init__(self, ed) -> None:
        super().__init__()
        self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
        self.conv1 = Conv2d_BN(ed, ed, 1, 1, 0, groups=ed)
        self.dim = ed
    
    def forward(self, x):
        return self.conv(x) + self.conv1(x) + x
    
    @torch.no_grad()
    def fuse(self):
        conv = self.conv.fuse()
        conv1 = self.conv1.fuse()
        
        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias
        
        conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])

        identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])

        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)
        return conv

实验结果

在这里插入图片描述
在这里插入图片描述

  • 5
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小杨小杨1

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值