【CNN】——RMNET推理时去掉残差模块(代码解析)

残差模块,resnet,repvgg
code:https://github.com/fxmeng/RMNet

1. 解决的问题

  • resnet在推理时的分支不友好
  • repvgg模块因为会在模块外面添加relu层,导致模型深度有影响
    虽然残差连接可以训练深度非常深的神经网络,但由于其多分支拓扑结构,对在线推理并不友好。这鼓励了许多研究人员去设计没有残差连接的DNN。例如,RepVGG在部署时将多分支拓扑重新参数化为类VGG(单分支)结构,在网络相对较浅的情况下表现出良好的性能。然而,RepVGG不能将ResNet等效地转换为VGG,因为重新参数化方法只能应用于线性块,而非线性层(ReLU)必须放在残差连接之外,这导致了表示能力有限,特别是对于更深层次的网络。

RM操作作为一种plugin方法,基本上有3个优点:

  • 其实现使其对高比率网络剪枝比较友好
  • 突破了RepVGG的深度限制
  • 与ResNet和RepVGG相比,RMNet具有更好的精度-速度权衡网络

2. 模型与代码

  1. repvgg的问题
    在这里插入图片描述

从图2可以看出,随着深度的增加,ResNet可以得到更好的精度,这与前面的分析一致。相比之下,RepVGG-133在CIFAR-100上的准确率为79.57%,而RepVGG-133的准确率仅为41.38%。

  1. RM操作
    图3显示了RM操作等效去除残差连接的过程。为简单起见,在图中没有显示BN层,输入通道、中间通道和输出通道的数量相同,并赋值为C。
    在这里插入图片描述
    resblock代码实现:
class ResBlock(nn.Module):
    def __init__(self, in_planes, mid_planes, out_planes, stride=1):
        super(ResBlock, self).__init__()

        assert mid_planes > in_planes

        self.in_planes = in_planes
        self.mid_planes = mid_planes - out_planes +in_planes
        self.out_planes = out_planes
        self.stride = stride

        self.conv1 = nn.Conv2d(in_planes, self.mid_planes - in_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.mid_planes - in_planes)
        
        self.conv2 = nn.Conv2d(self.mid_planes - in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        
        self.relu = nn.ReLU(inplace=True)
        
        self.downsample=nn.Sequential()
        if self.in_planes != self.out_planes or self.stride != 1:
            self.downsample=nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes))
        self.running1 = nn.BatchNorm2d(in_planes,affine=False)
        self.running2 = nn.BatchNorm2d(out_planes,affine=False)
        
    def forward(self, x):
        if self.in_planes == self.out_planes and self.stride == 1:
            self.running1(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.downsample(x)
        self.running2(out)
        return self.relu(out)
    
    def deploy(self, merge_bn=False):
        idconv1 = nn.Conv2d(self.in_planes, self.mid_planes, kernel_size=3, stride=self.stride, padding=1, bias=False).eval()
        idbn1=nn.BatchNorm2d(self.mid_planes).eval()
        
        nn.init.dirac_(idconv1.weight.data[:self.in_planes])
        bn_var_sqrt=torch.sqrt(self.running1.running_var + self.running1.eps)
        idbn1.weight.data[:self.in_planes]=bn_var_sqrt
        idbn1.bias.data[:self.in_planes]=self.running1.running_mean
        idbn1.running_mean.data[:self.in_planes]=self.running1.running_mean
        idbn1.running_var.data[:self.in_planes]=self.running1.running_var
        
        idconv1.weight.data[self.in_planes:]=self.conv1.weight.data
        idbn1.weight.data[self.in_planes:]=self.bn1.weight.data
        idbn1.bias.data[self.in_planes:]=self.bn1.bias.data
        idbn1.running_mean.data[self.in_planes:]=self.bn1.running_mean
        idbn1.running_var.data[self.in_planes:]=self.bn1.running_var
        
        idconv2 = nn.Conv2d(self.mid_planes, self.out_planes, kernel_size=3, stride=1, padding=1, bias=False).eval()
        idbn2=nn.BatchNorm2d(self.out_planes).eval()
        downsample_bias=0
        if self.in_planes==self.out_planes:
            nn.init.dirac_(idconv2.weight.data[:,:self.in_planes])
        else:
            idconv2.weight.data[:,:self.in_planes],downsample_bias=self.fuse(F.pad(self.downsample[0].weight.data, [1, 1, 1, 1]),self.downsample[1].running_mean,self.downsample[1].running_var,self.downsample[1].weight,self.downsample[1].bias,self.downsample[1].eps)

        idconv2.weight.data[:,self.in_planes:],bias=self.fuse(self.conv2.weight,self.bn2.running_mean,self.bn2.running_var,self.bn2.weight,self.bn2.bias,self.bn2.eps)
        
        bn_var_sqrt=torch.sqrt(self.running2.running_var + self.running2.eps)
        idbn2.weight.data=bn_var_sqrt
        idbn2.bias.data=self.running2.running_mean
        idbn2.running_mean.data=self.running2.running_mean+bias+downsample_bias
        idbn2.running_var.data=self.running2.running_var
        
        if merge_bn:
            return [torch.nn.utils.fuse_conv_bn_eval(idconv1,idbn1),self.relu,torch.nn.utils.fuse_conv_bn_eval(idconv2,idbn2),self.relu]
        else:
            return [idconv1,idbn1,self.relu,idconv2,idbn2,self.relu]


    def fuse(self,conv_w, bn_rm, bn_rv,bn_w,bn_b, eps):
        bn_var_rsqrt = torch.rsqrt(bn_rv + eps)
        conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
        conv_b = bn_rm * bn_var_rsqrt * bn_w-bn_b
        return conv_w,conv_b
  1. 残差模块的转换
    这里我们直接print原始的残差模块和转换后的残差模块。

原始残差模块:

(0): ResBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential()
      (running1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
      (running2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)
    )

转换后的RM模块

 (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU(inplace=True)
  (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (8): ReLU(inplace=True)

解释的图像
在这里插入图片描述

  • 将原来的conv2d(64, 64) ->conv2d(64, 128), 增加输出的channel数,增加的数量和输入的feature数一样。
  • 增加的kernel的权重满足dirac分布(只保留一个通道,其余为0)
  • 得到的结果等价于concat(x, conv(x)),相当于将输入特征和第一个conv的计算结果进行了通道拼接。
  • 第二个卷积的变换,是将输入通道增加输入的feature数,同时增加的kernel权重满足dirac分布。这一步的操作等价于残差模块的+

3. 优缺点

优点

  • 残差模块确实可以全部转换成卷积,bn,relu。感觉后续这些也可以合并
  • 等价之后方便裁剪,因为很多kernel权重为0

缺点

  • 转换为RM模块之后,kernel数增加了,虽然很多为0,但是做了很多无用计算。要在设备上实测推理速度。

reference

  1. RMNet推理去除残差结构让ResNet、MobileNet、RepVGG Great Again(必看必看)
  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值