Batch Normalization概念
Batch Normalization:批标准化
批:一批数据,通常为mini-batch
标准化:0均值,1方差
优点:
1. 可以用更大学习率,加速模型收敛
2. 可以不用精心设计权值初始化
3. 可以不用dropout或较小的dropout
4. 可以不用L2或者较小的weight decay
5. 可以不用LRN(local response normalization)
计算方式:
PyTorch的Batch Normalization 1d/2d/3d实现
_BatchNorm
• nn.BatchNorm1d
• nn.BatchNorm2d
• nn.BatchNorm3d
__init__(
self,
num_features, #一个样本特征数量(最重要)
eps=1e-5, #分母修正项
momentum=0.1, #指数加权平均估计当前mean/var
affine=True, #是否需要affine transform
track_running_stats=True#是训练状态,还是测试状态
)
主要属性:
running_mean:均值
running_var:方差
weight:affine transform中的gamma
bias: affine transform中的beta
running_mean = (1 - momentum) * pre_running_mean + momentum * mean_t
running_var = (1 - momentum) * pre_running_var + momentum * var_t
训练:均值和方差采用指数加权平均计算
测试:当前统计值
参考文献:《 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift》