为什么要Normalization
Internal Covariate Shift (ICS):数据尺度/分布异常,导致训练困难
常见的Normalization--BN、LN、IN and GN
常见的Normalization
1. Batch Normalization(BN)
2. Layer Normalization(LN)
3. Instance Normalization(IN)
4. Group Normalization(GN)
相同点:
区别:均值和方差求取方式不同
Layer Normalization
起因:BN不适用于变长的网络,如RNN
思路:逐层计算均值和方差
注意事项:
1. 不再有running_mean和running_var
2. gamma和beta为逐元素的
nn.LayerNorm(
normalized_shape, #该层特征形状
eps=1e-05, #分母修正项
elementwise_affine=True #是否需要affine transform
)
Instance Normalization
起因:BN在图像生成(Image Generation)中不适用
思路:逐Instance(channel)计算均值和方差
nn.InstanceNorm2d(
num_features, #一个样本特征数量(最重要)
eps=1e-05, #分母修正项
momentum=0.1, #指数加权平均估计当前mean/var
affine=False, #是否需要affine transform
track_running_stats=False#是训练状态,还是测试状态
)
Group Normalization
起因:小batch样本中,BN估计的值不准
思路:数据不够,通道来凑
注意事项:
1. 不再有running_mean和running_var
2. gamma和beta为逐通道(channel)的
应用场景:大模型(小batch size)任务
nn.GroupNorm(
num_groups, #分组数
num_channels, #通道数(特征数)
eps=1e-05, #分母修正项
affine=True#是否需要affine transform
)
Normalization小结
BN、LN、IN和GN都是为了克服Internal Covariate Shift (ICS)