PyTorch中的批处理规范化
在本集中,我们将看到如何向PyTorch CNN 添加批处理规范化。
什么是批处理规范化
为了理解批量归一化,我们需要先了解什么是一般的数据归一化,我们在数据集归一化的章节中了解了这个概念。
当我们对一个数据集进行归一化时,我们是在对将要传递给网络的输入数据进行归一化,而当我们在网络中加入批量归一化时,我们是在数据通过一层或多层后再次进行归一化。
可能想到的一个问题:如果输入已经标准化,为什么还要再次标准化?
随着数据开始通过层移动,随着层转换的执行,值将开始移动。标准化图层的输出可确保当数据从输入到输出流经网络时,尺度保持在特定范围内。
通常使用的特定归一化技术称为标准化。这就是我们使用平均值和标准差来计算z分数。
批量规范如何工作
当使用批次规范化时,平均数和标准差值是在应用规范化时相对于批次计算的。这与整个数据集相反,就像我们看到的数据集归一化。
此外,还有两个可学习的参数,这些参数允许对数据进行缩放和移位。我们在论文中看到了这一点:批量标准化:通过减少内部协方差来加速深度网络训练
需要注意的是,γ确定的缩放对应的是乘法运算,β确定的sift对应的是加法运算。
Scale和Sift操作听起来很花哨,但它们的意思很简单,就是乘法和加法。
这些可学习的参数让数值的分布有更大的自由度,可以自由移动,调整到合适的位置。
Scale和Sift可以看作是一条线的斜率和y截距值,这两个值可以让这条线调整到适合二维平面上的不同位置。
将批处理规范添加到CNN
好了,让我们创建两个网络,一个有批处理规范,一个没有。然后,我们将使用我们在课程中开发的测试框架来测试这些设置。要做到这一点,我们将利用nn.Sequential
类。
我们的第一个网络将称为network1
:
torch.manual_seed(50)
network1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
, nn.ReLU()
, nn.MaxPool2d(kernel_size=2, stride=2)
, nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
, nn.ReLU()
, nn.MaxPool2d(kernel_size=2, stride=2)
, nn.Flatten(start_dim=1)
, nn.Linear(in_features=12*4*4, out_features=120)
, nn.ReLU()
, nn.Linear(in_features=120, out_features=60)
, nn.ReLU()
, nn.Linear(in_features=60, out_features=10)
)
我们的第二个网络将称为network2
:
torch.manual_seed(50)
network2 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
, nn.ReLU()
, nn.MaxPool2d(kernel_size=2, stride=2)
, nn.BatchNorm2d(6)
, nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
, nn.ReLU()
, nn.MaxPool2d(kernel_size=2, stride=2)
, nn.Flatten(start_dim=1)
, nn.Linear(in_features=12*4*4, out_features=120)
, nn.ReLU()
, nn.BatchNorm1d(120)
, nn.Linear(in_features=120, out_features=60)
, nn.ReLU()
, nn.Linear(in_features=60, out_features=10)
)
现在,我们将创建一个networks
字典,用于存储两个网络。
networks = {
'no_batch_norm': network1
,'batch_norm': network2
}
这个字典的名称或键将在我们的运行循环中用于访问每个网络。为了配置我们的运行,我们可以使用字典中的键,而不是直接写出每个值。这是非常方便,因为它允许我们轻松地测试不同的网络,只需将更多的网络添加到字典中。
params = OrderedDict(
lr = [.01]
, batch_size = [1000]
, num_workers = [1]
, device = ['cuda']
, trainset = ['normal']
, network = list(networks.keys())
)
现在,我们在运行循环内所做的所有事情就是简单地使用运行对象访问网络,该对象使我们能够访问网络字典。就像这样:
network = networks[run.network].to(device)
我们已经准备好进行测试。结果如下:
批量标准化给我们带来了迄今为止最高的精度。
在本集中,我们将看到如何向卷积神经网络添加批处理规范化。
英文原文地址:https://deeplizard.com/learn/video/bCQ2cNhUWQ8