Simulating BN's forward pass
import torch
import torch.nn as nn
import torch.nn.modules.batchnorm
# create a random input
def create_inputs():
    return torch.randn(8, 3, 20, 20)
# using BatchNorm2d as the example
# when mean_val / var_val are not None, skip computing statistics from
# the input and use the passed-in mean and variance directly
def dummy_bn_forward(x, bn_weight, bn_bias, eps, mean_val=None, var_val=None):
    if mean_val is None:
        mean_val = x.mean([0, 2, 3])
    if var_val is None:
        # note: torch.var computes the unbiased estimate by default,
        # so unbiased=False must be set explicitly here
        var_val = x.var([0, 2, 3], unbiased=False)
    x = x - mean_val[None, ..., None, None]
    x = x / torch.sqrt(var_val[None, ..., None, None] + eps)
    x = x * bn_weight[..., None, None] + bn_bias[..., None, None]
    return mean_val, var_val, x
Note that x.var uses the unbiased estimate by default, so we have to disable it manually for the normalization step. However, when BN updates running_var with momentum, it does use the unbiased estimate, i.e. the batch variance multiplied by n / (n - 1).
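To make this concrete, here is a minimal sketch (the tensor shape is arbitrary) verifying that the unbiased estimate is exactly the biased one scaled by n / (n - 1):

import torch

x = torch.randn(8, 3, 20, 20)
n = x.numel() / x.size(1)  # number of elements per channel
biased_var = x.var([0, 2, 3], unbiased=False)
unbiased_var = x.var([0, 2, 3], unbiased=True)
# the unbiased estimate divides by n - 1 instead of n
assert torch.allclose(unbiased_var, biased_var * n / (n - 1))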
Updating running_mean and running_var
running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)
momentum = 0.1  # this is also the default momentum when BN is initialized
bn_layer = nn.BatchNorm2d(num_features=3, momentum=momentum)

# simulate 10 forward passes
for t in range(10):
    inputs = create_inputs()
    bn_outputs = bn_layer(inputs)
    inputs_mean, inputs_var, _ = dummy_bn_forward(
        inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps
    )
    n = inputs.numel() / inputs.size(1)
    # update running_var and running_mean
    running_var = running_var * (1 - momentum) + momentum * inputs_var * n / (n - 1)
    running_mean = running_mean * (1 - momentum) + momentum * inputs_mean

assert torch.allclose(running_var, bn_layer.running_var)
assert torch.allclose(running_mean, bn_layer.running_mean)
print(f'bn_layer running_mean is {bn_layer.running_mean}')
print(f'dummy bn running_mean is {running_mean}')
print(f'bn_layer running_var is {bn_layer.running_var}')
print(f'dummy bn running_var is {running_var}')
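Beyond the running statistics, one can also check the train-mode output itself; a brief sketch reusing the objects above (the atol here is an assumption, meant to absorb minor floating-point differences between the fused kernel and the manual computation):

inputs = create_inputs()
bn_outputs = bn_layer(inputs)
_, _, expected_outputs = dummy_bn_forward(
    inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps
)
# in train mode BN normalizes with the biased batch statistics,
# so the two outputs should agree up to floating-point tolerance
assert torch.allclose(bn_outputs, expected_outputs, atol=1e-6)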
eval mode
Everything verified above concerns BN's behavior in train mode; in eval mode, a few attributes matter.

track_running_stats: defaults to True. In train mode, running_mean and running_var are accumulated, and in eval mode these statistics are used as μ and σ; when set to False, eval mode computes the mean and variance directly from the input.

running_mean, running_var: the statistics accumulated in train mode.

In other words, BN.training is not the only attribute that determines BN's behavior: whenever BN.training or not BN.track_running_stats holds, the mean and variance are computed directly from the input data; otherwise the accumulated statistics are used instead.
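As a sanity check of this rule, a minimal sketch reusing bn_layer and dummy_bn_forward from above: in eval mode (with track_running_stats=True), the output should match a manual computation that plugs in the accumulated statistics.

bn_layer.eval()
inputs = create_inputs()
bn_outputs = bn_layer(inputs)
# pass running_mean / running_var explicitly so dummy_bn_forward skips
# computing batch statistics, mirroring eval-mode behavior
_, _, expected_outputs = dummy_bn_forward(
    inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps,
    bn_layer.running_mean, bn_layer.running_var
)
assert torch.allclose(bn_outputs, expected_outputs, atol=1e-6)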