(本文首发于公众号)
上一篇文章介绍了如何把 BatchNorm 和 ReLU 合并到 Conv 中,这篇文章会介绍具体的代码实现。本文相关代码都可以在 github 上找到。
Folding BN
回顾一下前文把 BN 合并到 Conv 中的公式:
KaTeX parse error: No such environment: align at position 8: \begin{̲a̲l̲i̲g̲n̲}̲ y_{bn}&=\frac{…
其中, x x x 是卷积层的输入, w w w、 b b b 分别是 Conv 的参数 weight 和 bias, γ \gamma γ、 β \beta β 是 BN 层的参数。
对于 BN 的合并,首先,我们需要熟悉 pytorch 中的 BatchNorm2d
模块。
pytorch 中的 BatchNorm2d
针对 feature map 的每一个 channel 都会计算一个均值和方差,所以公式 (1) 需要对 weight 和 bias 进行 channel wise 的计算。另外,BatchNorm2d
中有一个布尔变量 affine
,当该变量为 true 的时候,(1) 式中的 γ \gamma γ 和 β \beta β 就是可学习的, BatchNorm2d
会中有两个变量:weight
和 bias
,来分别存放这两个参数。而当 affine
为 false 的时候,就直接默认 γ = 1 \gamma=1 γ=1, β = 0 \beta=0 β=0,相当于 BN 中没有可学习的参数。默认情况下,我们都设置 affine=True
。
我们沿用之前的代码,先定义一个 QConvBNReLU
模块:
class QConvBNReLU(QModule):
def __init__(self, conv_module, bn_module, qi=True, qo=True, num_bits=8):
super(QConvBNReLU, self).__init__(qi=qi, qo=qo, num_bits=num_bits)
self.num_bits = num_bits
self.conv_module = conv_module
self