1. Batch Normalization Concepts
1.1 What Batch Normalization Is
Batch Normalization: normalizing a batch of data.
- Batch: a batch of data, usually a mini-batch
- Normalization: zero mean, unit variance

Advantages:
- Allows a larger learning rate, which speeds up convergence
- Removes the need for carefully designed weight initialization
- Allows dropping dropout, or using a smaller dropout rate
- Allows dropping L2 regularization, or using a smaller weight decay
- Makes LRN (local response normalization) unnecessary

Reference: "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift" (Ioffe & Szegedy, 2015)
Computation:
Input: a mini-batch of $m$ values $\mathcal{B} = \{x_1, \dots, x_m\}$ and two learnable parameters $\gamma, \beta$
Output: the normalized values $\{y_i = \mathrm{BN}_{\gamma,\beta}(x_i)\}$
- Compute the mini-batch mean $\mu_{\mathcal{B}} = \frac{1}{m}\sum_{i=1}^{m} x_i$ and variance $\sigma_{\mathcal{B}}^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_{\mathcal{B}})^2$
- Standardize each value: $\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}$, where $\epsilon$ is a small correction term that keeps the denominator from being 0
- Apply an affine transform $y_i = \gamma \hat{x}_i + \beta$, i.e. a scale and a shift, which restores the layer's capacity
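The three steps above can be sketched directly (a minimal sketch, not the paper's reference code; the tensor layout and the eps value follow PyTorch's `nn.BatchNorm1d` defaults):

```python
import torch

def batch_norm_1d(x, gamma, beta, eps=1e-5):
    """Batch-normalize a (batch, features) tensor feature-wise."""
    mean = x.mean(dim=0)                         # 1. mini-batch mean
    var = x.var(dim=0, unbiased=False)           #    and (biased) variance
    x_hat = (x - mean) / torch.sqrt(var + eps)   # 2. standardize; eps guards the denominator
    return gamma * x_hat + beta                  # 3. affine transform: scale and shift

x = torch.randn(16, 4) * 3 + 5                   # a mini-batch far from 0-mean/1-std
gamma, beta = torch.ones(4), torch.zeros(4)
y = batch_norm_1d(x, gamma, beta)
print(y.mean(dim=0), y.std(dim=0))               # per-feature mean ~0, std ~1
```

With $\gamma$ initialized to 1 and $\beta$ to 0, the result matches `nn.BatchNorm1d` in training mode.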
1.2 Internal Covariate Shift (ICS)
ICS can be loosely understood as a change in the scale or distribution of the data as it flows through the network.
From the figure above, $D(H_1) = n \cdot D(X) \cdot D(W) = 1$: the variance of the first hidden layer's output is the product of the number of inputs $n$, the variance of the input, and the variance of the weights connecting them. Any small change in the data's variance is therefore multiplied layer after layer, and as the network deepens the drift grows until gradients vanish or explode.
In other words, a shift in the data's scale or distribution makes the model hard to train, and Batch Normalization was proposed precisely to solve this problem.
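The amplification can be checked numerically (an illustrative sketch with hypothetical sizes, not the figure's original code): with $D(X) = D(W) = 1$ and $n = 256$, a single linear layer already scales the standard deviation to roughly $\sqrt{n} = 16$, while standardizing the output pulls it back to 1.

```python
import torch

n = 256
x = torch.randn(16, n)   # D(X) = 1
w = torch.randn(n, n)    # D(W) = 1
h = x @ w                # D(H1) = n * D(X) * D(W) = n, so std ~ sqrt(256) = 16

print(h.std().item())    # around 16: one layer already scaled the data 16x

h_norm = (h - h.mean(dim=0)) / h.std(dim=0)   # standardization undoes the drift
print(h_norm.std().item())                    # back to ~1
```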
1.3 Applying Batch Normalization
1.3.1 With BN, weight initialization can be skipped
# -*- coding: utf-8 -*-
import torch
import numpy as np
import torch.nn as nn
from tools.common_tools import set_seed

set_seed(1)  # set the random seed


class MLP(nn.Module):
    def __init__(self, neural_num, layers=100):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(neural_num) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):
        for (i, linear), bn in zip(enumerate(self.linears), self.bns):
            x = linear(x)
            x = bn(x)  # apply the BN layer before the activation
            x = torch.relu(x)

            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

            print("layers:{}, std:{}".format(i, x.std().item()))
        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # method 1
                # nn.init.normal_(m.weight.data, std=1)  # normal: mean=0, std=1
                # method 2: kaiming
                nn.init.kaiming_normal_(m.weight.data)


neural_nums = 256
layer_nums = 100
batch_size = 16

net = MLP(neural_nums, layer_nums)
# net.initialize()

inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

output = net(inputs)
print(output)
The output above shows that with BN, even without any weight initialization, the standard deviation of every layer's output stays well-behaved.
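One detail worth keeping in mind when applying BN (a small sketch based on `nn.BatchNorm1d`'s documented behavior, not code from this tutorial): during training BN normalizes with the current batch's statistics and keeps exponential running estimates of mean and variance; after `net.eval()` it switches to those running statistics, so inference no longer depends on how the batch is composed.

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)          # running_mean starts at 0, running_var at 1
x = torch.randn(64, 4) * 2 + 3  # data with mean ~3, std ~2

bn.train()
_ = bn(x)                       # normalized with batch stats; running stats updated (momentum=0.1)
print(bn.running_mean)          # moved from 0 toward the batch mean

bn.eval()
y_single = bn(x[:1])            # eval mode uses running stats, so even batch size 1 works
```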
1.3.2 Applying BN to a binary classification model
# -*- coding:utf-8 -*-
"""
@file name : bn_application.py
@author    : TingsongYu https://github.com/TingsongYu
@date      : 2019-11-01
@brief     : using nn.BatchNorm
"""
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed


class LeNet_bn(nn.Module):
    def __init__(self, classes):
        super(LeNet_bn, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(num_features=6)
        self.conv2