Batch Normalization概念
批标准化
- 批:一批数据,通常为mini-batch
- 标准化:0均值,1方差
优点:
1、可以用更大的学习率,加速模型的收敛
2、可以不用精心设计权值初始化
3、可以不用dropout或者较小的dropout
4、可以不用L2或者较小的weight decay
5、可以不用LRN(local response normalization)
计算方式
affine transform 增强Capacity(容纳能力)
内部协变量转移
_BatchNorm
__init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True)
nn.BatchNorm1d
nn.BatchNorm2d
nn.BatchNorm3d
参数:
- num_features: 一个样本特征数量
- eps:分母修正项
- momentum: 指数加权平均估计当前mean/var
- affine: 是否需要affine transform
- track_running_stats: 是训练状态还是测试状态
主要属性:
- running_mean: 均值
- running_var: 方差
- weight: affine transform 中的gamma
- bias: affine transform 中的beta
训练:均值和方差采用指数加权平均计算
测试:当前统计值
running_mean = (1-momentum) * pre_running_mean + momentum * mean_t
running_var = (1-momentum) * pre_running_var + momentum * var_t