BN
- ICS
BN是为了解决深度学习中的 Internal Covariate Shift 问题及其影响,ICS产生的原因是由于参数更新带来的网络中每一层输入值分布的改变,并且随着网络层数的加深而变得更加严重,这就使得高层需要不断去重新适应底层的参数更新,因此我们可以通过固定每一层网络输入值的分布来对减缓ICS问题。
- 白化
在BN之前,白化(whitening)是一个重要的数据预处理步骤,包含两个目的:
1.使得所有特征具有相同的均值和方差,即同分布,PCA白化保证了所有特征分布均值为0,方差为1;而ZCA白化则保证了所有特征分布均值为0,方差相同。
2.消除特征之间的相关性,使特征独立。
白化主要有以下两个问题:
1.标准的白化操作代价高昂,特别是我们还希望白化操作是可微的,保证白化操作可以通过反向传播来更新梯度。
2.白化过程由于改变了网络每一层的分布,因而改变了网络层中本身数据的表达能力。底层网络学习到的参数信息会被白化操作丢失掉。
- BN
因此,以 BN 为代表的 Normalization 方法退而求其次,进行了简化的白化操作。基本思想是:尝试单独对每个特征进行normalizaiton就可以了,让每个特征都有均值为0,方差为1的分布。再加个线性变换操作,让这些数据再能够尽可能恢复本身的表达能力。
h
=
f
(
g
×
x
−
μ
θ
+
b
)
h = f(g\times\frac{x-\mu}{\theta}+b)
h=f(g×θx−μ+b)
x
^
=
x
−
μ
θ
\hat x=\frac{x-\mu}{\theta}
x^=θx−μ:得到均值为0,方差为1的标准分布
y
=
g
×
x
^
+
b
y=g\times\hat x+b
y=g×x^+b:最终得到的数据符合均值为b、方差为
g
2
g^2
g2的分布
当训练完成后,测试模型时,我们保留了每组mini-batch训练数据在网络中每一层的
μ
b
a
t
c
h
\mu_ {batch}
μbatch与
θ
b
a
t
c
h
\theta _ {batch}
θbatch。此时我们使用整个样本的统计量来对Test数据进行归一化,具体来说使用均值与方差的无偏估计。
E
(
μ
b
a
t
c
h
)
=
μ
t
e
s
t
E(\mu_ {batch}) = \mu_ {test}
E(μbatch)=μtest
E
(
θ
t
e
s
t
2
)
=
n
n
−
1
E
(
θ
b
a
t
c
h
2
)
E(\theta _ {test}^2)=\frac{n}{n-1}E(\theta _ {batch}^2)
E(θtest2)=n−1nE(θbatch2)
- 无偏估计
估计量的数学期望等于被估计参数的真实值,则称此此估计量为被估计参数的无偏估计。无偏估计的意义是:在多次重复下,它们的平均数接近所估计的参数真值。
X
ˉ
=
1
n
∑
i
=
1
N
X
i
\bar{X} = \frac{1}{n}\sum_{i=1}^N X_i
Xˉ=n1i=1∑NXi
E
(
X
ˉ
)
=
E
(
X
)
=
μ
E(\bar{X}) = E(X) = \mu
E(Xˉ)=E(X)=μ
样本均值
X
ˉ
\bar{X}
Xˉ的数学期望等于总体平均值
μ
\mu
μ,说明
X
ˉ
\bar{X}
Xˉ是总体平均值
μ
\mu
μ的无偏估计。
ζ
2
=
1
n
∑
i
=
1
N
(
X
i
−
X
ˉ
)
2
\zeta^2=\frac{1}{n} \sum_{i=1}^N (X_i-\bar{X})^2
ζ2=n1i=1∑N(Xi−Xˉ)2
E
(
ζ
2
)
=
n
−
1
n
θ
2
E(\zeta^2) =\frac{n-1}{n} \theta^2
E(ζ2)=nn−1θ2
E
(
S
2
)
=
n
n
−
1
E
(
ζ
2
)
=
θ
2
E(S^2)=\frac{n}{n-1}E(\zeta^2) = \theta^2
E(S2)=n−1nE(ζ2)=θ2
- 优点
1.BN使得网络中每层输入数据的分布相对稳定,加速模型学习速度
BN通过规范化与线性变换使得每一层网络的输入数据的均值与方差都在一定范围内,使得后一层网络不必不断去适应底层网络中输入的变化,从而实现了网络中层与层之间的解耦,允许每一层进行独立学习,有利于提高整个神经网络的学习速度。
2.BN使得模型对网络中的参数不那么敏感,简化调参过程,使得网络学习更加稳定
使用BN的网络将不会受到参数数值大小的影响,当权重
W
W
W按照常量
λ
\lambda
λ进行伸缩时,即:
B
N
(
λ
W
x
)
=
B
N
(
g
×
λ
W
x
−
λ
μ
λ
θ
+
b
)
=
B
N
(
W
x
)
BN(\lambda Wx)=BN(g\times\frac{\lambda Wx-\lambda\mu}{\lambda\theta}+b) = BN(Wx)
BN(λWx)=BN(g×λθλWx−λμ+b)=BN(Wx)
∂
B
N
(
λ
W
x
)
∂
x
=
∂
B
N
(
g
×
λ
W
x
−
λ
μ
λ
θ
+
b
)
∂
x
=
∂
B
N
(
W
x
)
∂
x
\frac{\partial BN(\lambda Wx)}{\partial x}=\frac{\partial BN(g\times\frac{\lambda Wx-\lambda\mu}{\lambda\theta}+b)}{\partial x} = \frac{\partial BN(Wx)}{\partial x}
∂x∂BN(λWx)=∂x∂BN(g×λθλWx−λμ+b)=∂x∂BN(Wx)
∂
B
N
(
λ
W
x
)
∂
λ
W
=
∂
B
N
(
g
×
λ
W
x
−
λ
μ
λ
θ
+
b
)
∂
λ
W
=
1
λ
∂
B
N
(
W
x
)
∂
W
\frac{\partial BN(\lambda Wx)}{\partial \lambda W}=\frac{\partial BN(g\times\frac{\lambda Wx-\lambda\mu}{\lambda\theta}+b)}{\partial \lambda W} = \frac{1}{\lambda}\frac{\partial BN(Wx)}{\partial W}
∂λW∂BN(λWx)=∂λW∂BN(g×λθλWx−λμ+b)=λ1∂W∂BN(Wx)
经过BN操作以后,权重的缩放值会被“抹去”,因此保证了输入数据分布稳定在一定范围内。另外,权重的缩放并不会影响到对
x
x
x的梯度计算;并且当权重越大时,即
λ
\lambda
λ越大,
1
/
λ
1/\lambda
1/λ越小,意味着权重
W
W
W的梯度反而越小,这样BN就保证了梯度不会依赖于参数的scale,使得参数的更新处在更加稳定的状态,相当于实现了参数正则化的效果,避免参数的大幅震荡,提高网络的泛化性能。
3.BN具有一定的正则化效果
在Batch Normalization中,由于我们使用mini-batch的均值与方差作为对整体训练样本均值与方差的估计,尽管每一个batch中的数据都是从总体样本中抽样得到,但不同mini-batch的均值与方差会有所不同,这就为网络的学习过程中增加了随机噪音,与Dropout通过关闭神经元给网络训练带来噪音类似,在一定程度上对模型起到了正则化的效果。
4.BN允许网络使用饱和性激活函数(例如sigmoid,tanh等),缓解梯度消失问题
在不使用BN层的时候,由于网络的深度与复杂性,很容易使得底层网络变化累积到上层网络中,导致模型的训练很容易进入到激活函数的梯度饱和区;通过normalize操作可以让激活函数的输入数据落在梯度非饱和区,缓解梯度消失的问题。
另外,原作者通过也证明了网络加入BN后,可以丢弃Dropout,模型也同样具有很好的泛化效果。
train和eval
- train模式
running_mean = (1 - momentum) * mean_old + momentum * mean_new
running_var = (1 - momentum) * var_old + momentum * var_new - eval模式
running_mean = mean_old
running_var =val_old
BN层与Conv层的融合
- BN层的计算公式如下,推理的时候,参数都是已知
- 卷积计算公式如下
- 融合结果如下
对于一个C * H * W的特征图,经过卷积后,每一个通道的feature map会对应不同的BN层参数,因此可以写成对角线的二维向量格式,方便运算
def fuse_conv_and_bn(conv, bn):
# Fuse Conv2d() and BatchNorm2d() layers https://tehnokv.com/posts/fusing-batchnorm-and-conv/
fusedconv = nn.Conv2d(conv.in_channels, // 3
conv.out_channels, // 16
kernel_size=conv.kernel_size, // (6, 6)
stride=conv.stride,
padding=conv.padding,
groups=conv.groups,
bias=True).requires_grad_(False).to(conv.weight.device)
# Prepare filters
'''
conv.weight.shape: torch.Size([16, 3, 6, 6])
w_conv.shape: torch.Size([16, 108])
bn.weight.shape: torch.Size([16])
torch.diag: 返回对角线格式二维张量
w_bn.shape: torch.Size([16, 16])
fusedconv.weight.shape: torch.Size([16, 3, 6, 6])
'''
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.shape))
# Prepare spatial bias
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
# fusedconv.bias.shape: torch.Size([16])
fusedconv.bias.copy_(torch.mm(w_bn, b_conv.reshape(-1, 1)).reshape(-1) + b_bn)
return fusedconv