批量归一化(Batch Normalization)
一、为什么要使用批量归一化:
二、什么是批量归一化:
批量归一化先用当前小批量(batch)的均值和方差对数据做标准化,使其均值为0、方差为1,然后通过可学习的缩放参数 γ 和偏移参数 β 对标准化后的数据进行缩放和平移:$y = \gamma \cdot \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} + \beta$。
批量归一化一般可以加快模型的收敛速度,但通常不会明显改变模型的最终精度。
三、代码实现:
import torch
from torch import nn
from d2l import torch as d2l
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum,
               training=None):
    """Apply batch normalization to ``X`` and update the running statistics.

    In training mode the mean and (biased) variance of the current batch are
    used to standardize ``X`` and are blended into the running statistics.
    In evaluation mode ``X`` is standardized with the running statistics,
    which are returned unchanged.

    Args:
        X: Input tensor. In training mode it must be 2-D (statistics are
            taken over dim 0) or 4-D (statistics are taken over dims 0, 2
            and 3, i.e. one value per index along dim 1).
        gamma: Scale applied to the standardized input; must broadcast
            against ``X``.
        beta: Shift added after scaling; must broadcast against ``X``.
        moving_mean: Running mean used in evaluation mode.
        moving_var: Running variance used in evaluation mode.
        eps: Constant added to the variance before the square root, for
            numerical stability.
        momentum: Weight kept on the old running statistic; the batch
            statistic receives ``1 - momentum``.
        training: ``True`` to use batch statistics, ``False`` to use the
            running statistics. ``None`` (default) keeps the historical
            behaviour of treating "autograd enabled" as training mode.

    Returns:
        A tuple ``(Y, moving_mean, moving_var)`` where ``Y`` is the
        normalized, scaled and shifted output and the other two are the
        (possibly updated) running statistics, detached from the autograd
        graph.

    Raises:
        ValueError: If ``X`` is neither 2-D nor 4-D in training mode.
    """
    if training is None:
        # Legacy heuristic: callers wrap inference in ``torch.no_grad()``.
        training = torch.is_grad_enabled()

    if not training:
        # Evaluation: normalize with the accumulated running statistics.
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        if X.dim() not in (2, 4):
            raise ValueError(
                f"batch_norm expects a 2-D or 4-D input, got {X.dim()}-D")
        if X.dim() == 2:
            # Fully connected case: one statistic per column (feature).
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Convolutional case: one statistic per channel (dim 1);
            # keepdim=True keeps the result broadcastable against X.
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        # Training: normalize with the statistics of the current batch.
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # The running statistics are bookkeeping only, so keep them out of
        # the autograd graph.
        with torch.no_grad():
            moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
            moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta  # learnable scale and shift
    return Y, moving_mean.detach(), moving_var.detach()
# Batch-normalization layer
class BatchNorm(nn.Module):
    """Batch-normalization layer built on top of ``batch_norm``.

    Holds the learnable scale (``gamma``) and shift (``beta``) together with
    the running mean and variance that ``batch_norm`` uses in evaluation
    mode.

    Args:
        num_features: Number of outputs of the preceding fully connected
            layer, or number of output channels of the preceding
            convolutional layer.
        num_dims: ``2`` for inputs coming from a fully connected layer;
            any other value selects the 4-D (convolutional) layout.
    """

    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        # Learnable scale and shift, initialized to the identity transform.
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Running statistics are not trained by the optimizer. Registering
        # them as buffers makes ``.to(device)`` move them with the module
        # and includes them in ``state_dict()``.
        self.register_buffer("moving_mean", torch.zeros(shape))
        self.register_buffer("moving_var", torch.ones(shape))

    def forward(self, X):
        """Normalize ``X`` and store the updated running statistics."""
        # Assigning a tensor to a registered buffer name updates the buffer.
        Y, self.moving_mean, self.moving_var = batch_norm(
            X, self.gamma, self.beta, self.moving_mean, self.moving_var,
            eps=1e-5, momentum=0.9)
        return Y
# LeNet-style model with the custom BatchNorm layer inserted before each
# activation to speed up convergence.
net = nn.Sequential(
    # Convolutional stage 1
    nn.Conv2d(1, 6, kernel_size=5),
    BatchNorm(6, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2
    nn.Conv2d(6, 16, kernel_size=5),
    BatchNorm(16, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Fully connected classifier head
    nn.Flatten(),
    nn.Linear(16 * 4 * 4, 120),
    BatchNorm(120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    BatchNorm(84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)
# Same architecture, using PyTorch's built-in batch-normalization layers.
net = nn.Sequential(
    # Convolutional stage 1
    nn.Conv2d(1, 6, kernel_size=5),
    nn.BatchNorm2d(6),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2
    nn.Conv2d(6, 16, kernel_size=5),
    nn.BatchNorm2d(16),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Fully connected classifier head
    nn.Flatten(),
    nn.Linear(256, 120),
    nn.BatchNorm1d(120),
    nn.Sigmoid(),
    nn.Linear(120, 84),
    nn.BatchNorm1d(84),
    nn.Sigmoid(),
    nn.Linear(84, 10),
)