1. 引言
本文重点介绍BatchNorm
的定义和相关特性,并介绍了其详细实现和具体应用。希望可以帮助大家加深对其理解。
嗯嗯,闲话少说,我们直接开始吧!
2. 什么是BatchNorm?
BatchNorm是2015年提出的网络层,这个层具有以下特性:
-
易于训练:由于网络权重的分布随这一层的变化小得多,因此我们可以使用更高的学习率。我们在训练中收敛的方向没有那么不稳定,这样我们就可以更快地朝着loss收敛的方向前进。
-
提升正则化:尽管网络在每个epoch都会遇到相同的训练样本,但每个小批量的归一化是不同的,因此每次都会稍微改变其值。
-
提升精度:可能是由于前面两点的结合,论文提到他们获得了比当时最先进的结果更好的准确性。
3. BatchNorm是如何工作的?
BatchNorm
所做的是确保接收到的输入具有平均值0和标准偏差1。
本文中介绍的算法如下:
下面是我自己用pytorch进行的实现:
import numpy as np
import torch
from torch import nn
from torch.nn import Parameter
class BatchNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1):
super().__init__()
self.gamma = Parameter(torch.Tensor(num_features))
self.beta = Parameter(torch.Tensor(num_features))
self.register_buffer("moving_avg", torch.zeros(num_features))
self.register_buffer("moving_var", torch.ones(num_features))
self.register_buffer("eps", torch.tensor(eps))
self.register_buffer("momentum", torch.tensor(momentum))
self._reset()
def _reset(self):
self.gamma.data.fill_(1)
self.beta.data.fill_(0)
def forward(self, x):
if self.training:
mean = x.mean(dim=0)
var = x.var(dim=0)
self.moving_avg = self.moving_avg * momentum + mean * (1 - momentum)
self.moving_var = self.moving_var * momentum + var * (1 - momentum)
else:
mean = self.moving_avg
var = self.moving_var
x_norm = (x - mean) / (torch.sqrt(var + self.eps))
return x_norm * self.gamma + self.beta
这里对其进行补充说明如下:
- 我们在训练和推理过程中BatchNorm有不同的行为。在训练中,我们记录均值和方差的指数移动平均值,以供以后在推理时使用。其原因是,在训练期间处理批次时,我们可以获得输入随时间变化的均值和方差的更好估计,然后将其用于推理。在推理过程中使用输入批次的平均值和方差将不太准确,因为其大小可能比训练中使用的小得多,大数定律在这里发挥了作用。
4. 什么时候使用Batchnorm ?
这似乎总是有帮助的,所以没有理由不使用它。通常它出现在全连接层/卷积层和激活函数之间。但也有人认为,最好把它放在激活层之后。我找不到任何关于激活函数之后使用它的论文,所以最安全的选择是按照每个人的做法,在激活函数前使用它。
5. 一些技巧总结
列举下关于实际应用中BatchNorm的技巧总结如下:
- 我们知道,一个已经训练的网络包含用于训练它的数据集的移动平均值和方差,这可能是一个问题。在迁移学习期间,我们通常会冻结大部分层,如果不小心,BatchNorm层也会冻结,这意味着应用的移动平均值属于原始数据集,而不是新数据集。解冻BatchNorm层是一个好主意,将允许网络重新计算自己数据集上的移动平均值和方差。