1. 先看用法
import torch
import torch.nn as nn
input = torch.randn(1, 2, 3, 4)
print(input)
bn = nn.BatchNorm2d(num_features=2)
res = bn(input)
print(res)
2. 作用
其实就是将一批feature map进行标准化处理
。我们都学过正态分布的表达,
x
ˉ
i
=
x
−
μ
σ
2
{\bar x_i} = \frac{{x - \mu }}{{{\sigma ^2}}}
xˉi=σ2x−μ。
同理,看看官方的表达式:
y
=
x
−
E
[
x
]
V
a
r
[
x
]
+
ϵ
∗
γ
+
β
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
y=Var[x]+ϵx−E[x]∗γ+β
3. 使用规则
计算的就是各个维度上的标准化,注意维度之间的对应规则