batch normalization(批量归一化)
BN在深度网络训练过程中是非常好用的trick,而之前只是大概知道它的作用,很多细节并不清楚,因此希望用这篇文章彻底解决揭开BN的面纱。
BN层的由来与概念
- 讲解BN之前,我们需要了解BN是怎么被提出的。在机器学习领域,数据分布是很重要的概念。如果训练集和测试集的分布很不相同,那么在训练集上训练好的模型,在测试集上应该不奏效(比如用ImageNet训练的分类网络去在灰度医学图像上finetune再测试,效果应该不好)。对于神经网络来说,如果每一层的数据分布都不一样,后一层的网络则需要去学习适应前一层的数据分布,这相当于去做了domain的adaptation,无疑增加了训练难度,尤其是网络越来越深的情况。
- 实际上,确实如此,不同层的输出的分布是有差异的。BN的那篇论文中指出,不同层的数据分布会往激活函数的上限或者下限偏移。论文称这种偏移为internal Covariate Shift,internal指的是网络内部。神经网络一旦训练起来,那么参数就要发生更新,除了输入层的数据外(因为输入层数据,我们已经人为的为每个样本归一化),后面网络每一层的输入数据分布是一直在发生变化的,因为在训练的时候,前面层训练参数的更新将导致后面层输入数据分布的变化。以网络第二层为例:网络的第二层输入,是由第一层的参数和input计算得到的,而第一层的参数在整个训练过程中一直在变化,因此必然会引起后面每一层输入数据分布的改变, 第一层输出变化了,势必会引起第二层输入分布的改变,模型拟合的效果就会变差,也会影响模型收敛的速度(例如我原本的参数是拟合分布A的,然后下一轮更新的时候,样本都是来自分布B的,对于这组参数来说,这些样本就会很陌生)
- BN就是为了解决偏移的,解决的方式也很简单,就是让每一层的分布都normalize到标准高斯分布。(BN是根据划分数据的集合去做Normalization,不同的划分方式也就出现了不同的Normalization,如GN,LN,IN)
BN层的算法详解
Batch Normalization(BN)层用于在深层神经网络中,对每层的输入进行归一化,加速神经网络的收敛并提高泛化效果。BN层包含两个步骤:
1)对每个mini-batch的输入进行归一化
2)通过可学习的缩放和平移参数,来学习输入的均值和标准差,以增加模型的表达能力。
对于这两个可学习的参数解释如下:
我们前面提到了,前面的层引起了数据分布的变化,这时候可能有一种思路是说:在每一层输入的时候,再加一个预处理就好。比如归一化到均值为0,方差为1,然后再输入进行学习。基本思路是这样的,然而实际上没有这么简单,如果我们只是使用简单的归一化方式: x ~ i = x i − μ B σ B 2 + ε \tilde{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \varepsilon}} x~i=σB2+εxi−μB 对某一层的输入数据做归一化,然后送入网络的下一层,这样是会影响到本层网络所学习的特征的,比如网络中学习到的数据本来大部分分布在0的右边,经过RELU激活函数以后大部分会被激活,如果直接强制归一化,那么就会有大多数的数据无法激活了,这样学习到的特征不就被破坏掉了么?论文中对上面的方法做了一些改进:变换重构,引入了可以学习的两个参数: y i = γ x ~ i + β y_i = \gamma \tilde{x}_i + \beta yi=γx~i+β 这样的时候可以恢复出原始的某一层学习到的特征的,因此我们引入这个可以学习的参数使得我们的网络可以恢复出原始网络所要学习的特征分布。
移动平均更新均值和方差
- 对于归一化的均值和标准差,BN层通过运行均值和方差来近似计算它们。在训练过程中,每个mini-batch的均值和方差会被计算出来并用于归一化。而在测试阶段,使用整个训练集的均值和方差来估计测试集上的均值和方差,然后使用这些值对测试集进行归一化。这样做会导致BN层的均值和方差的估计可能不准确。
- 为了解决这个问题,BN层使用一种叫做‘移动平均’或‘指数平均’的技巧来估计整个训练集的均值和方差。具体而言,它通过不断更新均值和方差的指数滑动平均值来估计它们的整体的均值和标准差。这个方法可以保证在前向传播时使用的均值和标准差是全局统计量的稳定估计,有效地增强了模型的泛化能力。因此,移动平均的方式更新均值和方差是BN层中的一项重要的技巧。
样例详解
- 假设输入的数据维度为(64, 3, 256, 256)分别对应N,C,H,W。
- 首先计算数据各个通道的均值和方差,数据的通道数为3,所有会分别得到3个均值和方差,在训练阶段,我们利用移动平均来更新我们后续要用的均值和方差,在推理测试阶段,我们利用训练集上稳定估计的均值和方差。
- 利用移动平均的方式进行均值和方差的更新(训练过程,测试推理过程略过)
- 对输入数据进行归一化(注意数据维度,利用python矩阵乘法的广播机制)
- 使用可学习的缩放、偏移参数
python代码
import torch
import torchvision
import cv2
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
class BatchNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super(BatchNorm, self).__init__()
self.num_feautre = num_features
self.eps = eps
self.momentum = momentum
# 可学习参数
self.weight = nn.Parameter(torch.zeros(size=(num_features,)))
self.bias = nn.Parameter(torch.zeros(size=(num_features,)))
# 非可学习的缓存参数
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.zeros(num_features))
def forward(self, x):
if self.training:
# 计算输入均值和方差
mean = x.mean(dim=(0, 2, 3), keepdim=True)
var = x.var(dim=(0, 2, 3), keepdim=True)
# 更新缓存值
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.squeeze()
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()
else:
# 使用缓存值进行计算
mean = self.running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
var = self.running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
# 归一化操作
x = (x - mean) / torch.sqrt(var + self.eps)
# 使用缩放和平移操作
x = x * self.weight.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + \
self.bias.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
return x
def transformer(img):
"""
:param img:
:return:
"""
transform = transforms.Compose([
transforms.ToTensor(),
# transforms.Resize((224, 224)),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
return transform(img)
if __name__ == '__main__':
image = cv2.imread('data/test_data.png')
H, W, C = image.shape
image = transformer(image).reshape(1, C, H, W).repeat(4, 1, 1, 1)
# torch BN
bn_torch = nn.BatchNorm2d(momentum=0.01, eps=0.001, num_features=3)
out_torch_bn = bn_torch(image)
# my BN
BN = BatchNorm(momentum=0.01, eps=0.001, num_features=C)
BN.weight = bn_torch.weight
BN.bias = bn_torch.bias
out_my_bn = BN(image)
print(torch.allclose(out_torch_bn, out_my_bn, rtol=1e-03, atol=1e-05))
最后比较了一下手动实现和torch自带BN包的结果,没错(●’◡’●)