网络inference加速:Fuse Conv&BN

本文参考自ZJC师兄的知乎:链接

当前CNN卷积层的基本组成单元标配:Conv + BN +ReLU 三剑客。但其实在网络的推理阶段,可以将BN层的运算融合到Conv层中,减少运算量,加速推理。

本质上是修改了卷积核的参数,在不增加Conv层计算量的同时,略去了BN层的计算量。

公式推导如下:
公式推导

附一个代码实现:在这里插入图片描述

def fuse_conv_and_bn(conv, bn):
    # https://tehnokv.com/posts/fusing-batchnorm-and-conv/
    with torch.no_grad():
        # init
        fusedconv = torch.nn.Conv2d(conv.in_channels,
                                    conv.out_channels,
                                    kernel_size=conv.kernel_size,
                                    stride=conv.stride,
                                    padding=conv.padding,
                                    bias=True)

        # prepare filters
        w_conv = conv.weight.clone().view(conv.out_channels, -1)
        w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
        fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))

        # prepare spatial bias
        if conv.bias is not None:
            b_conv = conv.bias
        else:
            b_conv = torch.zeros(conv.weight.size(0))
        b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
        fusedconv.bias.copy_(b_conv + b_bn)

        return fusedconv

这里稍微解释一下代码内容:

  • 对于W的计算,其实只需要在原W基础上乘个系数,所以源码中这里将W拉伸为一个行向量组成的矩阵(每个行向量对应一个out_channel),并与以对应位置系数为元素的对角矩阵相乘,得到一个新行向量组成的矩阵,最后再恢复为原尺度即可;注意这里需要拉伸为矩阵是因为只有二维矩阵才能做.mm乘法,四维矩阵不能直接做乘法
  • 另外,源码这里对bias的计算,其实没有严格按照上面的推导公式做,而是将原公式中b的系数从‘μ/sqrt(…)’变成了1
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值