MMDetection3D代码学习笔记——fuse-conv-bn的作用
fuse-conv-bn在mmdetection3d中的参数设置代码
parser.add_argument(
'--fuse-conv-bn',
action='store_true',
help='Whether to fuse conv and bn, this will slightly increase'
'the inference speed')
由help中的说明可以看出fuse-conv-bn的主要作用是加速模型的推理速度(increase the inference speed)
fuse-conv-bn为什么能够加快模型的推理速度
原因:当前CNN卷积层的基本组成单元为:Conv+BN+ReLu三剑客,这几乎成为标配。但其实在网络的推理阶段,可以将BN层的运算融合到Conv层中,减少运算量,加速推理。本质上是修改了卷积核的参数,在不增加Conv层计算量的同时,略去了BN层的计算量。公式推导如下。
附一个代码实现:
def fuse_conv_and_bn(conv, bn):
# https://tehnokv.com/posts/fusing-batchnorm-and-conv/
with torch.no_grad():
# init
fusedconv = torch.nn.Conv2d(conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True)
# prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
# prepare spatial bias
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros(conv.weight.size(0))
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(b_conv + b_bn)
return fusedconv
这里稍微解释一下代码内容:
-
对于W的计算,其实只需要在原W基础上乘个系数,所以源码中这里将W拉伸为一个行向量组成的矩阵(每个行向量对应一个out_channel),并与以对应位置系数为元素的对角矩阵相乘,得到一个新行向量组成的矩阵,最后再恢复为原尺度即可;注意这里需要拉伸为矩阵是因为只有二维矩阵才能做.mm乘法,四维矩阵不能直接做乘法
-
另外,源码这里对bias的计算,其实没有严格按照上面的推导公式做,而是将原公式中b的系数从‘μ/sqrt(…)’变成了1