一、BN的作用
原因:当网络很深的时候,如果初始输入数据很小,比如介于[0, 1],前向传播时,会导致数据越来越小,最后导致数据趋向于0。导致反向传播时,梯度可能会消失,使模型无法训练。如果输入数据很大时,前向传播使数据越来越大,反向传播求梯度时,梯度可能会爆炸,同样不利于训练。
二、BN优点
三、BN算法
四、PyTorch的BN
注:不管有多少个样本,一个样本有多少个特征维度(特征图个数)。BN都是在相同的特征维度上进行计算均值、方差、γ和β。
四个重要参数:
weight: 用于保存模型表达能力的γ
bias:用于保存模型表达能力的β
running_mean:特征维度对应的均值
running_std:特征维度对应的方差
五、验证代码
# -*- coding: utf-8 -*-
"""
@file name : bn_in_123_dim.py
@author : QuZhang
@date : 2021-1-1 20:51
@brief : bn的三种维度函数
"""
from tools.common_tools import set_seed
import torch
import torch.nn as nn
set_seed(1)
if __name__ == "__main__":
# ---------------- nn.BatchNorm1d ------------
# 一维的BN层:特征里最小的特征单元是1维
# flag = True
flag = False
if flag:
batch_size = 3 # 3个样本
num_features = 5 # 5个特征维度
momentum = 0.1 # 用于加权平均
features_shape = (1)
# 1D : 一个样本,一个特征维度
feature_map = torch.ones(features_shape) # 最小的特征单元
# print("feature_map: ", feature_map.shape)
# 2D : 一个样本,多个特征维度
# 在第一个维度进行扩展
feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0) # (扩展后的值,扩展的维度)
# print("feature_maps: ", feature_maps)
# 3D : 多个样本,多个特征维度
feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0) # 批量的数据
print("input data:\n{} shape is: {}".format(feature_maps_bs, feature_maps_bs.shape))
bn = nn.BatchNorm1d(num_features=num_features, momentum=momentum)
running_mean, running_var = 0, 1 # 初始化上一次的均值和方差
for i in range(2):
outputs = bn(feature_maps_bs)
# 使用BN计算所有特征维度的均值和方差
print("\niteration: {}, running mean: {}".format(i, bn.running_mean))
print("iteration: {}, running var: {}".format(i, bn.running_var))
# 手动计算第二个特征维度的均值和方差
mean_t, var_t = 2, 0 # 当前均值和方差
# 用当前均值和方差与之前的均值和方差指数加权平均得到新的方差和均值
running_mean = (1-momentum) * running_mean + momentum * mean_t
running_var = (1-momentum) * running_var + momentum * var_t
print("iteration:{}, 第二个特征的running mean: {} ".format(i, running_mean))
print("iteration:{}, 第二个特征的running var:{}".format(i, running_var))
print("outputs:\n", outputs.data)
# ---------------- nn.BatchNorm2d --------------
# 二维的BN层:特征里最小的特征单元是2维
flag = True
if flag:
batch_size = 3
num_features = 6 # 特征维度数
momentum = 0.1
features_shape = (2, 2) # 一个特征维度里的数据是2D
feature_map = torch.ones(features_shape) # 最小的特征单元 2D
feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)
feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)
print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))
bn = nn.BatchNorm2d(num_features=num_features, momentum=momentum)
running_mean, running_var = 0, 1
for i in range(2):
outputs = bn(feature_maps_bs)
# 验证BN在同一维度上计算
print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))
print("\niter:{}, weight.shape: {}".format(i, bn.weight.shape))
print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))