Original post: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
Background: my team has been using BN fusion to speed up model inference. I had a rough idea of the principle, but remembered none of the details, so this time I worked through it and wrote it down.
1. Advantages
At inference time, fusing Conv+BN into a single Conv reduces compute, parameter count, and memory accesses; in short, it speeds the model up, with no loss of accuracy (apart from tiny floating-point rounding differences).
2. Principle
See the linked article, which explains it clearly. The fusion is essentially an algebraic identity on the data: at inference time BN applies a fixed per-channel affine transform, which can be expressed as a $1 \times 1$ convolution and folded into the preceding convolution.
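Concretely, following the linked article: write the output of one convolution channel as $z = W \ast x + b$. At inference, BN uses its running statistics $\mu, \sigma^2$ and learned affine parameters $\gamma, \beta$ to compute

$$\hat{z} = \gamma \cdot \frac{z - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta.$$

Substituting $z = W \ast x + b$ and regrouping yields a single convolution with

$$W_{\text{fused}} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}\, W, \qquad b_{\text{fused}} = \frac{\gamma\,(b - \mu)}{\sqrt{\sigma^2 + \epsilon}} + \beta.$$

These are exactly the per-channel quantities the code below builds: `w_bn` is the diagonal matrix of $\gamma / \sqrt{\sigma^2 + \epsilon}$ factors, and `b_bn` is $\beta - \gamma\mu / \sqrt{\sigma^2 + \epsilon}$.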
3. Code
import torch
import torch.nn as nn
import torchvision

def fuse_conv_and_bn(conv, bn):
    # init: the fused conv always needs a bias, even if the original conv has none
    fusedconv = torch.nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        groups=conv.groups,
        bias=True
    )
    # prepare filters: scale each output channel by gamma / sqrt(running_var + eps)
    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
    fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
    # prepare spatial bias: the conv bias must pass through the same BN scaling,
    # so it is multiplied by w_bn as well (the original article adds b_conv
    # unscaled, which only happens to work when conv.bias is None, as it is
    # for the ResNet convolutions tested below)
    if conv.bias is not None:
        b_conv = conv.bias
    else:
        b_conv = torch.zeros(conv.weight.size(0))
    b_conv = torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1)
    b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
    fusedconv.bias.copy_(b_conv + b_bn)
    return fusedconv
if __name__ == "__main__":
    torch.set_grad_enabled(False)
    x = torch.randn(16, 3, 256, 256)
    rn18 = torchvision.models.resnet18(pretrained=True)
    rn18.eval()
    print(rn18)
    net = torch.nn.Sequential(
        rn18.conv1,
        rn18.bn1
    )
    y1 = net.forward(x)
    fusedconv = fuse_conv_and_bn(net[0], net[1])
    y2 = fusedconv.forward(x)
    d = (y1 - y2).norm().div(y1.norm()).item()
    print("error: %.8f" % d)
Reference
https://nenadmarkus.com/p/fusing-batchnorm-and-conv/