BatchNorm

模拟BN 的 forward

import torch
import torch.nn as nn
import torch.nn.modules.batchnorm

# 创建随机输入
def create_inputs():
    return torch.randn(8, 3, 20, 20)

# 以 BatchNorm2d 为例
# mean_val, var_val 不为None时,不对输入进行统计,而直接用传进来的均值、方差
def dummy_bn_forward(x, bn_weight, bn_bias, eps, mean_val=None, var_val=None):
    if mean_val is None:
        mean_val = x.mean([0, 2, 3])
    if var_val is None:
        # 这里需要注意,torch.var 默认算无偏估计,因此需要手动设置unbiased=False
        var_val = x.var([0, 2, 3], unbiased=False)

    x = x - mean_val[None, ..., None, None]
    x = x / torch.sqrt(var_val[None, ..., None, None] + eps)
    x = x * bn_weight[..., None, None] + bn_bias[..., None, None]
    return mean_val, var_val, x

值得注意的是 在x。var的地方,默认使用了无偏估计,我们需要手动的取消掉,但是在BN中的running_var在动量更新的时候却使用的是无偏估计。即 * n / (n - 1)

running_mean、running_var 的更新

running_mean = torch.zeros(3)
running_var = torch.ones_like(running_mean)
momentum = 0.1 # 这也是BN初始化时momentum默认值
bn_layer = nn.BatchNorm2d(num_features=3, momentum=momentum)

# 模拟 forward 10 次
for t in range(10):
    inputs = create_inputs()
    bn_outputs = bn_layer(inputs)
    inputs_mean, inputs_var, _ = dummy_bn_forward(
        inputs, bn_layer.weight, bn_layer.bias, bn_layer.eps
    )
    n = inputs.numel() / inputs.size(1)
    # 更新 running_var 和 running_mean
    running_var = running_var * (1 - momentum) + momentum * inputs_var * n / (n - 1)
    running_mean = running_mean * (1 - momentum) + momentum * inputs_mean

assert torch.allclose(running_var, bn_layer.running_var)
assert torch.allclose(running_mean, bn_layer.running_mean)
print(f'bn_layer running_mean is {bn_layer.running_mean}')
print(f'dummy bn running_mean is {running_mean}')
print(f'bn_layer running_var is {bn_layer.running_var}')
print(f'dummy bn running_var is {running_var}')

eval 模式

上面验证的都是 train 模式下 BN 的表现,eval 模式有几个重要的参数。

  • track_running_stats默认为True,train 模式下统计running_meanrunning_var,eval 模式下用统计数据作为 μ \mu μ σ \sigma σ.设置为False时,eval模式直接计算输入的均值和方差。
  • running_meanrunning_var:train 模式下的统计量。

也就是说,BN.training 并不是决定 BN 行为的唯一参数。满足BN.training or not BN.track_running_stats就会直接计算输入数据的均值方差,否则用统计量代替。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值