torch.nn.BatchNorm3d

本文详细介绍了PyTorch中nn.BatchNorm3d的功能及使用方法。该层应用于5D输入,可在训练过程中减少内部协变量偏移,提高深度网络训练速度。文中还提供了实例演示如何创建带及不带可学习参数的批量归一化层。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

CLASS torch.nn.BatchNorm3d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

正如论文 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 所描述的,在一个5D输入上应用Batch Normalization。

The mean and standard-deviation are calculated per-dimension over the mini-batches. \gamma\beta是可学习的大小为C(C是输入的大小)的参数向量。默认情况下,\gamma被设为1,\beta被设为0。通过有偏估计计算标准差,等同于torch.var(input, unbiased=False)。

默认情况下,在训练时,这层保持计算均值和方差的估计,这个估计被用于验证时的normalization。运行的估计以动量0.1被保持。

如果track_running_stats 被设置成 False ,这层不保持运行估计,而且在验证时,也使用批统计。

 

注意:这里的参数momentum跟优化器里用的momentum不是一个概念。从数学上来说,运行统计的更新规则是\hat x_{new} = (1-momentum)\times \hat x + momentum \times x_t,这里\hat x是估计的统计量,而且x_t是新的观察值。

 

因为Batch Normalization是在通道维度上做的,即在(N, D, H, W)上计算统计量,常用的术语是把它叫做Volumetric Batch Normalization或Spatio-temporal Batch Normalization。

Example:

>>> # With Learnable Parameters
>>> m = nn.BatchNorm3d(100)
>>> # Without Learnable Parameters
>>> m = nn.BatchNorm3d(100, affine=False)
>>> input = torch.randn(20, 100, 35, 45, 10)
>>> output = m(input)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值