Batch Normalization
加快模型收敛速度,并不会对模型准确率有提升。
模型收敛速度与初始化方式和Batch Normalization有关。
Xavier初始化配合数据预处理时的Normalize操作(将图片像素的分布变为均值为0,方差为1,也叫白化操作),可以保证最开始模型训练的时候
1)各层激活值的方差为1。
2)损失函数关于各层激活值的梯度的方差为1。
这两点也称Glorot条件。满足Glorot条件,可以使得损失函数关于参数的梯度保持稳定,加快模型收敛。
但在模型训练时,各层激活值的方差可能会变化,不再为1,变化为多少,太复杂了,人是不能计算出来的。
但我们可以看到这个重要的思想,就是让各层激活值尽量稳定。稳定即限制它们的均值和方差。
既然不知道要限制它们的方差和均值为多少,就可以去学习它们。
如上图所示的公式,各层激活值的方差与Var(zi)有关,Var(zi)的值如上公式所计算。
我们经过bn层学习后的Var(h),可以让Var(zi)保持一定的值。
所以bn层的思想就是这么来的:以全连接网络为例,做完一个全连接层后,我们可以加一个bn层,这个bn层将这个全连接网络的输出(假设有n个结点,这n个结点论文中说就可以看为n个特征,满足n个不同的分布,对应n个(mean,std))。要学的就是这n个(mean,std)。
通过这个bn层,将一个小批量中的所有数据都变为服从(mean,std)的分布,就保证了各层激活值的相对稳定,即Glorot条件中的1),有助于加快模型收敛。
具体做法:
bn层训练和测试时的操作不同:
训练时
假设B为某一小批量,先按下面的式子计算出它们的均值和标准差:
再按如下式子对xi进行缩放,将其分布的均值变为β,方差变为γ。其中γ和β是bn层学习到的参数。
到此就完成了训练时bn层的前向传播。
测试时
测试时的μB和σB采用训练时各小批量μB和σB的滑动平均,具体做法类似带动量的SGD。
γ和β用训练时学习好的参数即可。
卷积层的bn:
全连接网络按特征做bn。卷积层,考虑1*1的卷积,从通道间的角度来看,就类似一个全连接层。所以卷积间的bn是按通道做的。卷积结果的通道数就是特征数,每个特征的样本数就是该通道上的每个像素点,由于一批大小不止1,样本数就是一个小批量上每张图某个特定通道上的所有样本点。
网上的实现:
def batch_norm(X,gamma,beta,moving_mean,moving_var,eps,momentum):
#gamma,beta:要学习的参数
if torch.is_grad_enabled():
# 训练模式
assert len(X.shape) in (2,4)
if len(X.shape) == 2:
# 全连接层
mean = X.mean(dim=0) # 对featuer维做平均
var = ((X - mean)**2).mean(dim=0) #方差
else:
# 卷积层
mean = X.mean(dim=(0,2,3),keepdim=True)
var = ((X - mean)**2).mean(dim=(0,2,3),keepdim=True)
X_hat = (X-mean) / torch.sqrt(var + eps) #标准化~(0,1)
moving_mean = momentum * moving_mean + (1 - momentum) * mean
moving_var = momentum * moving_var + (1 - momentum) * var
else:
# test模式
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
Y = X_hat * gamma + beta #转化为~(beta,gamma)
return Y,moving_mean.data,moving_var.data
class BN(nn.Module):
def __init__(self,num_dims,feature_nums):
super(BN, self).__init__()
if num_dims == 2:
shape = (1,feature_nums)
elif num_dims == 4:
shape = (1,feature_nums,1,1)
self.gamma = nn.Parameter(torch.ones(shape))
self.beta = nn.Parameter(torch.zeros(shape))
self.moving_mean = torch.ones(shape)
self.moving_var = torch.zeros(shape)
def __call__(self,X):
return self.forward(X)
def forward(self,X):
if self.moving_mean.device != X.device:
self.moving_mean = self.moving_mean.to(X.device)
self.moving_var = self.moving_var.to(X.device)
Y,self.moving_mean,self.moving_var = batch_norm(X,self.gamma,self.beta,self.moving_mean,self.moving_var,epos=1e-5,momentum=0.9)
return Y