本文是 Speeding up model with fusing batch normalization and convolution - LearnML.Today的翻译。conv+bn融合主要是在推理阶段进行加速,BN在推理时无需更新参数,且推理过程满足Conv的计算公式,能合二为一。好处是加快了推理,在量化任务中,也提高了精度(在高精度先乘,相比转换为低精度再乘,减小了精度损失)。YOLOv5中使用了该技术。
这是量化和推理优化模型中常用技术。

今天我们将试着理解如何使我们的模型在推理上更快一些。
大量的网络使用 BN 来提高网络的泛化能力 。但是在推理阶段,Batch Normalization 被关闭,取代使用的是每个通道的均值 和方差
的近似值。最酷的是我们可以通过1x1卷积实现同样的行为。更好的是,我们可以把BN和前面的卷积合并。
Batch Normalization
设 x 是我们要 normalize 的网络中的一个信号(激活)。给定一组这样的信号 来自于在一个batch 中 处理不同的样本,每一个都被normalized 如下:

和
是在一个 batch 上计算的 均值和方差(mean and variance),
是数值稳定性的一个小常数,
是比例因子,
是 转换因子。在训练期间,
和
对于每个 batch 都被重新计算:

参数 和
和网络的其他参数一起从梯度下降慢慢地中学习。在测试期间,我们通常不会在图像的一个batch上运行网络。因此,前面提到公式中的
和
不能使用。我们使用在训练中通过指数移动平均( exponential moving average)计算它们的估计值。让我们标记它们的近似值为
和
。
目前,batch normalization 主要应用于卷积神经网络对图像的处理。在该设置中,输入特征图的每个通道都有均值和方差估计、比例和转换参数。我们将这些表示为:对于通道c:,
,
和
解决
实现冻结的Batch Normalization为一个1×1 Convolution
给定一个具有形状 顺序的特征图F, 为了得到 它的 normalized 版本
。使用上面的公式,我们需要计算每个空间位置
的
。

我们清晰的看到:这是 ,它可以实现为 一个 1×1 Convolution。甚至,因为BN经常在卷积层后,我们可以把卷积和BN融合为一个。
使用一个卷积层融合batch normalization
设, 和
- 是BN的参数
和
- 是在BN前面的卷积层的参数
- 卷积层的输入
- 输入层的通道数
k - filter 的size
的
部分被 reshaped 到一个shape 为
的向量
, 因此产生的公式:
![]()
因此,我们可以使用下面的参数通过一个单个卷积层替换卷积+BN 两层。
- filter weights:
- bias:
使用 PyTorch 实现:

import torch
import torchvision
def fuse(conv, bn):
fused = torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
bias=True
)
# setting weights
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps+bn.running_var)))
fused.weight.copy_( torch.mm(w_bn, w_conv).view(fused.weight.size()) )
# setting bias
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros( conv.weight.size(0) )
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(
torch.sqrt(bn.running_var + bn.eps)
)
fused.bias.copy_( b_conv + b_bn )
return fused
# Testing
# we need to turn off gradient calculation because we didn't write it
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)
resnet18 = torchvision.models.resnet18(pretrained=True)
# removing all learning variables, etc
resnet18.eval()
model = torch.nn.Sequential(
resnet18.conv1,
resnet18.bn1
)
f1 = model.forward(x)
fused = fuse(model[0], model[1])
f2 = fused.forward(x)
d = (f1 - f2).mean().item()
print("error:",d)
其他参考资料:
【基础算法】六问透彻理解BN(Batch Normalization) - 知乎
7.5. 批量规范化 — 动手学深度学习 2.0.0-beta0 documentation
Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

1102

被折叠的 条评论
为什么被折叠?



