批量归一化
损失出现在最后,所以后面的层训练比较快,而数据在最底部,则:
- 底部的层训练较慢
- 底部层一变化,所有都会跟着变化
- 最后的层需要重新学习多次
最后导致收敛变慢。
或许我们可以通过固定输出和梯度的特定分布,即均值和方差在一定范围内,来进行优化,以提高数据和损失的稳定性
1.批量归一化
固定小批量里面的均值和方差
μ
B
=
1
∣
B
∣
∑
i
∈
B
x
i
σ
B
2
=
1
∣
B
∣
∑
i
∈
B
(
x
i
−
μ
B
)
2
+
ϵ
(
加一个小数值避免为
0
)
\mu_B=\frac {1}{|B|}\sum_{i\in B} x_i\\ \sigma^2_B = \frac {1}{|B|}\sum_{i\in B}(x_i -\mu _B)^2 +\epsilon (加一个小数值避免为0)
μB=∣B∣1i∈B∑xiσB2=∣B∣1i∈B∑(xi−μB)2+ϵ(加一个小数值避免为0)
再做额外的调整(可学习的参数
γ
(
方差
)
,
β
(
均值
)
\gamma(方差),\beta (均值)
γ(方差),β(均值))
x
i
+
1
=
B
N
(
x
i
)
=
γ
x
i
−
μ
b
σ
B
+
β
x_{i+1} = BN(x_i) = \gamma\frac{x_i -\mu_b}{\sigma _B}+\beta
xi+1=BN(xi)=γσBxi−μb+β
2.批量归一化层
可学习的参数 γ , β \gamma,\beta γ,β
作用在全连接和卷积层的输出上,激活函数前;或全连接层和卷积层输入上
对全连接层,作用再特征维;对卷积层,作用在通道维
批量归一化是线性变换
2.1 全连接层
通常,我们将批量规范化层置于全连接层中的仿射变换和激活函数之间。 设全连接层的输入为x,权重参数和偏置参数分别为𝑊和𝑏,激活函数为𝜙,批量规范化的运算符为BN。 那么,使用批量规范化的全连接层的输出的计算详情如下:
h
=
ϕ
(
B
N
(
W
x
+
b
)
)
h=\phi (BN(Wx+b))
h=ϕ(BN(Wx+b))
2.2 卷积层
对于卷积层,我们可以在卷积层之后和非线性激活函数之前应用批量规范化。
当卷积有多个输出通道时,我们需要对这些通道的“每个”输出执行批量规范化,每个通道都有自己的拉伸(scale)和偏移(shift)参数,这两个参数都是标量。
假设我们的小批量包含𝑚个样本,并且对于每个通道,卷积的输出具有高度𝑝和宽度𝑞。 那么对于卷积层,我们在每个输出通道的𝑚⋅𝑝⋅𝑞个元素上同时执行每个批量规范化。
因此,在计算平均值和方差时,我们会收集所有空间位置的值,然后在给定通道内应用相同的均值和方差,以便在每个空间位置对值进行规范化。
批量归一化到底在做什么
可能是通过在每个小批量里加入噪音来控制模型复杂度,因此没必要和丢弃法混合使用(也只是可能)
μ B 、 σ B \mu_B、\sigma_B μB、σB是每一批次的均值和方差,其实每一批次都不一样,比较随机。
批量归一化可以加速收敛速度,但一般不改变模型精度。学习率就可以比较大了
3.代码实现
import torch
from torch import nn
from d2l import torch as d2l
'''从零实现'''
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
# moving_mean和moving_var 可近似认为是全局上的均值和方差,eps是方差小数值
# momentum 用于更新均值和方差,通常是0.9或一个固定的数字
# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式
if not torch.is_grad_enabled():
# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差
# 预测模式一般只有一个样本,所以需要用全局的均值和方差
X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
else:
assert len(X.shape) in (2, 4)
if len(X.shape) == 2:
# 使用全连接层的情况,计算特征维上的均值和方差
mean = X.mean(dim=0) # 按行求均值 1*n的向量
var = ((X - mean) ** 2).mean(dim=0) # 按行求方差, 1*n的向量
else:
# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。
# 这里我们需要保持X的形状以便后面可以做广播运算
mean = X.mean(dim=(0, 2, 3), keepdim=True) # dim:0批量大小,1输入输出通道,2高,3宽
# 则需要求出没一行的均值,最终是1 * n * 1 *1 形状
var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True) # 同理
# 训练模式下,用当前的均值和方差做标准化
X_hat = (X - mean) / torch.sqrt(var + eps)
# 更新移动平均的均值和方差,动量更新,i影响i+1
moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
moving_var = momentum * moving_var + (1.0 - momentum) * var
Y = gamma * X_hat + beta # 缩放和移位
return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):
# 批量归一化层
# num_features:完全连接层的输出数量或卷积层的输出通道数。
# num_dims:2表示完全连接层,4表示卷积层
def __init__(self, num_features, num_dims):
super().__init__()
if num_dims == 2:
shape = (1, num_features)
else:
shape = (1, num_features, 1, 1)
# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0
self.gamma = nn.Parameter(torch.ones(shape)) # gamma不能初始化为0,不然一乘全是0没办法学习了
self.beta = nn.Parameter(torch.zeros(shape))
# 非模型参数的变量初始化为0和1
self.moving_mean = torch.zeros(shape)
self.moving_var = torch.ones(shape) # 初始化为0,1正态分布
def forward(self, X):
# 如果X不在内存上,将moving_mean和moving_var
# 复制到X所在显存上
if self.moving_mean.device != X.device:
self.moving_mean = self.moving_mean.to(X.device)
self.moving_var = self.moving_var.to(X.device)
# 保存更新过的moving_mean和moving_var
Y, self.moving_mean, self.moving_var = batch_norm(
X, self.gamma, self.beta, self.moving_mean,
self.moving_var, eps=1e-5, momentum=0.9)
return Y
'''应用在LeNet上'''
net = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
nn.Linear(16 * 4 * 4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),
nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),
nn.Linear(84, 10))
'''简洁实现,使用nn.BatchNorm2d,只需要输入通道数作为参数'''
net2 = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),
nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
nn.Linear(84, 10))
lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()