CLASS torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
正如论文 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 所描述的,在一个5D输入上应用Batch Normalization。
The mean and standard-deviation are calculated per-dimension over the mini-batches. 和
是可学习的大小为C(C是输入的大小)的参数向量。默认情况下,
被设为1,
被设为0。通过有偏估计计算标准差,等同于torch.var(input, unbiased=False)。
默认情况下,在训练时,这层保持计算均值和方差的估计,这个估计被用于验证时的normalization。运行的估计以动量0.1被保持。
如果track_running_stats
被设置成 False ,这层不保持运行估计,而且在验证时,也使用批统计。
注意:这里的参数momentum跟优化器里用的momentum不是一个概念。从数学上来说,运行统计的更新规则是,这里
是估计的统计量,而且
是新的观察值。
因为Batch Normalization是在通道维度上做的,即在(N, D, H, W)上计算统计量,常用的术语是把它叫做Volumetric Batch Normalization或Spatio-temporal Batch Normalization。
Example:
>>> # With Learnable Parameters
>>> m = nn.BatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm3d(100, affine=False)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)