PyTorch Framework Learning Notes 17: Batch Normalization

These notes focus on a normalization method used everywhere in deep learning: Batch Normalization (BN).

BN was first proposed in the paper "Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift". Its original goal was to fix the shift in data scale/distribution that builds up as networks get deeper during training, but it turned out to bring many other useful benefits as well, so it has been widely adopted ever since. Building on the idea of BN, LN, IN and GN were later proposed, each suited to its own scenarios.

I. The Concept of BN

Batch Normalization: "batch" refers to a small batch of data, usually a mini-batch, and "normalization" means standardizing the data to zero mean and unit variance.

The BN paper itself is worth reading; here is a summary of the benefits of using BN:

  1. A larger learning rate can be used, speeding up convergence.
  2. Careful weight initialization is no longer required, which is very practical (it saves a lot of effort).
  3. Dropout can be removed or reduced.
  4. L2 regularization / weight decay can be removed or reduced.
  5. LRN (local response normalization) is no longer needed.
  6. And the original motivation of BN: it counteracts the shift in data scale/distribution caused by increasing network depth.

The BN algorithm proceeds as follows:
[Figure: the BN transform algorithm from the paper]
The inputs are a mini-batch of data plus the learnable parameters γ and β. First the mean and variance of the mini-batch are computed, then the data are standardized with them, so the result has zero mean and unit standard deviation. That is not the end, though: BN adds an affine transform that scales and shifts the standardized data. This step is learnable, with γ and β updated during training, so the model can apply the affine transform whenever it helps. Its purpose is to increase the model's capacity, making it more flexible and giving it more options.
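
To make the four steps concrete, here is a minimal sketch of the training-mode computation on a toy mini-batch (just the formulas written out, not the actual nn.BatchNorm1d implementation):

import torch

x = torch.randn(16, 5)                        # a mini-batch: 16 samples, 5 features
gamma, beta = torch.ones(5), torch.zeros(5)   # learnable scale and shift (initial values)
eps = 1e-5                                    # small constant to avoid dividing by zero

mean = x.mean(dim=0)                          # step 1: mini-batch mean, per feature
var = x.var(dim=0, unbiased=False)            # step 2: mini-batch (biased) variance, per feature
x_hat = (x - mean) / torch.sqrt(var + eps)    # step 3: standardize to zero mean, unit std
y = gamma * x_hat + beta                      # step 4: affine transform (scale and shift)

print(x_hat.mean(dim=0), x_hat.std(dim=0))    # close to 0 and 1 respectively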

II. Internal Covariate Shift (ICS)

This is the problem the BN paper originally set out to solve: during training, as the number of layers grows, the distribution of the data shifts from layer to layer. Here is an example, illustrated in the figure below:
[Figure: a fully-connected network used to illustrate how the activation scale changes across layers]
This is a fully-connected network. The first layer is the input X, and the weights of the first fully-connected layer are W1, so the output of that layer is H1 = X·W1. Assume the input X has zero mean and unit standard deviation, i.e., it has already been standardized.

If W is also initialized with unit standard deviation, the variance of H1 works out to n (the number of input neurons), as computed in the figure. So after just one fully-connected layer the variance of the data has grown n times; after many layers the distribution keeps expanding, the gradients computed in backpropagation also become very large, and we get gradient explosion.

Now imagine instead that W is initialized with a very small variance, less than 1/n. Then the variance of H1 will be less than 1, and after many layers the data shrink towards zero; the gradients in backpropagation become tiny as well, which is gradient vanishing.

That is ICS. What BN does is re-standardize the data back to zero mean and unit standard deviation after each fully-connected layer, so the shift never reaches the subsequent layers, and ICS is eliminated.
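
Before building the full network, here is a quick standalone check of the scale argument for a single layer (an illustrative sketch; with n = 256 the standard deviation should grow by roughly √256 = 16):

import torch

x = torch.randn(10000, 256)   # inputs with mean 0, std 1
w = torch.randn(256, 256)     # weights with mean 0, std 1
h = x @ w                     # one fully-connected layer, no bias
print(h.std().item())         # roughly 16, i.e. the variance grew about 256-fold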

Below we build a fully-connected network with 100 layers of 256 neurons each and observe how the data distribution changes with depth:

import torch
import numpy as np
import torch.nn as nn
import sys, os
from tools.common_tools import set_seed

set_seed(1)  # set the random seed

class MLP(nn.Module):
    def __init__(self, neural_num, layers=100):
        super(MLP, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
        self.bns = nn.ModuleList([nn.BatchNorm1d(neural_num) for i in range(layers)])
        self.neural_num = neural_num

    def forward(self, x):

        for (i, linear), bn in zip(enumerate(self.linears), self.bns):
            x = linear(x)
            # method 3: uncomment to add a BN layer after every linear layer
            # x = bn(x)
            x = torch.relu(x)

            if torch.isnan(x.std()):
                print("output is nan in {} layers".format(i))
                break

            print("layers:{}, std:{}".format(i, x.std().item()))

        return x

    def initialize(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):

                # method 1: plain normal initialization
                nn.init.normal_(m.weight.data, std=1)    # normal: mean=0, std=1

                # method 2: kaiming initialization (uncomment to use)
                # nn.init.kaiming_normal_(m.weight.data)


neural_nums = 256
layer_nums = 100
batch_size = 16

net = MLP(neural_nums, layer_nums)
net.initialize()

inputs = torch.randn((batch_size, neural_nums))  # normal: mean=0, std=1

output = net(inputs)
print(output)

The code covers three cases:

  • The first case initializes the weights from a normal distribution with zero mean and unit standard deviation, and uses no BN layer. Printing the standard deviation of each layer's output gives:
layers:0, std:9.352246284484863
layers:1, std:112.47123718261719
layers:2, std:1322.8056640625
layers:3, std:14569.42578125
layers:4, std:154672.765625
layers:5, std:1834038.125
layers:6, std:18807982.0
layers:7, std:209553056.0
layers:8, std:2637502976.0
layers:9, std:32415457280.0
layers:10, std:374825549824.0
layers:11, std:3912853094400.0
layers:12, std:41235926482944.0
layers:13, std:479620541448192.0
layers:14, std:5320927071961088.0
layers:15, std:5.781225696395264e+16
layers:16, std:7.022147146707108e+17
layers:17, std:6.994718592201654e+18
layers:18, std:8.473501588274335e+19
layers:19, std:9.339794309346954e+20
layers:20, std:9.56936220412742e+21
layers:21, std:1.176274650258599e+23
layers:22, std:1.482641634599281e+24
layers:23, std:1.6921343606923352e+25
layers:24, std:1.9741450942615745e+26
layers:25, std:2.1257213262324592e+27
layers:26, std:2.191710730990783e+28
layers:27, std:2.5254503817521246e+29
layers:28, std:3.221308876879874e+30
layers:29, std:3.530952437322462e+31
layers:30, std:4.525353644890983e+32
layers:31, std:4.715011552268428e+33
layers:32, std:5.369590669553154e+34
layers:33, std:6.712318470791119e+35
layers:34, std:7.451114589527308e+36
output is nan in 35 layers
tensor([[3.2626e+36, 0.0000e+00, 7.2932e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.5465e+38],
        [3.9237e+36, 0.0000e+00, 7.5033e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.1274e+38],
        [0.0000e+00, 0.0000e+00, 4.4932e+37,  ..., 0.0000e+00, 0.0000e+00,
         1.7016e+38],
        ...,
        [0.0000e+00, 0.0000e+00, 2.4222e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.5295e+38],
        [4.7380e+37, 0.0000e+00, 2.1580e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.6028e+38],
        [0.0000e+00, 0.0000e+00, 6.0878e+37,  ..., 0.0000e+00, 0.0000e+00,
         2.1695e+38]], grad_fn=<ReluBackward0>)

Exactly as analyzed above, the data grow larger and larger.

  • The second case is to find an initialization scheme under which ICS does not occur; here kaiming_normal initialization is used, with the following result:
layers:0, std:0.8266295790672302
layers:1, std:0.8786815404891968
layers:2, std:0.9134421944618225
layers:3, std:0.8892470598220825
layers:4, std:0.8344280123710632
layers:5, std:0.874537467956543
layers:6, std:0.7926970720291138
layers:7, std:0.7806458473205566
layers:8, std:0.8684563636779785
layers:9, std:0.9434137344360352
layers:10, std:0.964215874671936
layers:11, std:0.8896796107292175
layers:12, std:0.8287257552146912
layers:13, std:0.8519770503044128
layers:14, std:0.83543461561203
layers:15, std:0.802306056022644
layers:16, std:0.8613607287406921
layers:17, std:0.7583686709403992
layers:18, std:0.8120225071907043
layers:19, std:0.791111171245575
layers:20, std:0.7164373397827148
layers:21, std:0.778393030166626
layers:22, std:0.8672043085098267
layers:23, std:0.8748127222061157
layers:24, std:0.9020991921424866
layers:25, std:0.8585717082023621
layers:26, std:0.7824354767799377
layers:27, std:0.7968913912773132
layers:28, std:0.8984370231628418
layers:29, std:0.8704466819763184
layers:30, std:0.9860475063323975
layers:31, std:0.9080778360366821
layers:32, std:0.9140638113021851
layers:33, std:1.0099570751190186
layers:34, std:0.9909381866455078
layers:35, std:1.0253210067749023
layers:36, std:0.8490436673164368
layers:37, std:0.703953742980957
layers:38, std:0.7186156511306763
layers:39, std:0.7250635623931885
layers:40, std:0.7030817866325378
layers:41, std:0.6325559616088867
layers:42, std:0.6623691916465759
layers:43, std:0.6960877180099487
layers:44, std:0.7140734195709229
layers:45, std:0.6329052448272705
layers:46, std:0.645889937877655
layers:47, std:0.7354376912117004
layers:48, std:0.6710689067840576
layers:49, std:0.6939154863357544
layers:50, std:0.6889259219169617
layers:51, std:0.6331775188446045
layers:52, std:0.6029314398765564
layers:53, std:0.6145529747009277
layers:54, std:0.6636687517166138
layers:55, std:0.7440096139907837
layers:56, std:0.7972176671028137
layers:57, std:0.7606151103973389
layers:58, std:0.6968684196472168
layers:59, std:0.7306802868843079
layers:60, std:0.6875628232955933
layers:61, std:0.7171440720558167
layers:62, std:0.7646605968475342
layers:63, std:0.7965087294578552
layers:64, std:0.8833741545677185
layers:65, std:0.8592953681945801
layers:66, std:0.8092937469482422
layers:67, std:0.8064812421798706
layers:68, std:0.6792411208152771
layers:69, std:0.6583347320556641
layers:70, std:0.5702279210090637
layers:71, std:0.5084437727928162
layers:72, std:0.4869327247142792
layers:73, std:0.4635041356086731
layers:74, std:0.4796812832355499
layers:75, std:0.4737212061882019
layers:76, std:0.4541455805301666
layers:77, std:0.4971913695335388
layers:78, std:0.49279505014419556
layers:79, std:0.44223514199256897
layers:80, std:0.4802999496459961
layers:81, std:0.5579249858856201
layers:82, std:0.5283756852149963
layers:83, std:0.5451982617378235
layers:84, std:0.6203728318214417
layers:85, std:0.6571894884109497
layers:86, std:0.7036821842193604
layers:87, std:0.7321069836616516
layers:88, std:0.6924358606338501
layers:89, std:0.6652534604072571
layers:90, std:0.6728310585021973
layers:91, std:0.6606624126434326
layers:92, std:0.6094606518745422
layers:93, std:0.6019104719161987
layers:94, std:0.5954217314720154
layers:95, std:0.6624558568000793
layers:96, std:0.6377887725830078
layers:97, std:0.6079288125038147
layers:98, std:0.6579317450523376
layers:99, std:0.6668478846549988
tensor([[0.0000, 1.3437, 0.0000,  ..., 0.0000, 0.6444, 1.1867],
        [0.0000, 0.9757, 0.0000,  ..., 0.0000, 0.4645, 0.8594],
        [0.0000, 1.0023, 0.0000,  ..., 0.0000, 0.5148, 0.9196],
        ...,
        [0.0000, 1.2873, 0.0000,  ..., 0.0000, 0.6454, 1.1411],
        [0.0000, 1.3589, 0.0000,  ..., 0.0000, 0.6749, 1.2438],
        [0.0000, 1.1807, 0.0000,  ..., 0.0000, 0.5668, 1.0600]],
       grad_fn=<ReluBackward0>)

The data indeed no longer blow up or shrink rapidly as the network gets deeper, but finding a suitable initialization scheme often takes a lot of time.
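
As a side note (a small sketch, not part of the original experiment): kaiming_normal_ keeps the scale stable in ReLU networks by drawing each weight with standard deviation √(2 / fan_in), which compensates for ReLU zeroing out roughly half of the activations.

import math
import torch.nn as nn

layer = nn.Linear(256, 256, bias=False)
nn.init.kaiming_normal_(layer.weight)
print(layer.weight.std().item())   # close to the target value below
print(math.sqrt(2 / 256))          # about 0.088, the std kaiming_normal_ aims for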

  • The third case adds the BN layers (uncomment x = bn(x) under method 3 in the code above) and uses no special weight initialization:
layers:0, std:0.5751240849494934
layers:1, std:0.5803307890892029
layers:2, std:0.5825020670890808
layers:3, std:0.5823132395744324
layers:4, std:0.5860626101493835
layers:5, std:0.579832911491394
layers:6, std:0.5815905332565308
layers:7, std:0.5734466910362244
layers:8, std:0.5853903293609619
layers:9, std:0.5811620950698853
layers:10, std:0.5818504095077515
layers:11, std:0.5775734186172485
layers:12, std:0.5788553357124329
layers:13, std:0.5831498503684998
layers:14, std:0.5726235508918762
layers:15, std:0.5717664957046509
layers:16, std:0.576700747013092
layers:17, std:0.5848639607429504
layers:18, std:0.5718148350715637
layers:19, std:0.5775086879730225
layers:20, std:0.5790560841560364
layers:21, std:0.5815289616584778
layers:22, std:0.5845211744308472
layers:23, std:0.5830678343772888
layers:24, std:0.5817515850067139
layers:25, std:0.5793628096580505
layers:26, std:0.5744576454162598
layers:27, std:0.581753134727478
layers:28, std:0.5858433246612549
layers:29, std:0.5895737409591675
layers:30, std:0.5806193351745605
layers:31, std:0.5742025971412659
layers:32, std:0.5814924240112305
layers:33, std:0.5800969004631042
layers:34, std:0.5751299858093262
layers:35, std:0.5819362998008728
layers:36, std:0.57569420337677
layers:37, std:0.5824175477027893
layers:38, std:0.5741908550262451
layers:39, std:0.5768386721611023
layers:40, std:0.578640341758728
layers:41, std:0.5833579301834106
layers:42, std:0.5873513221740723
layers:43, std:0.5807022452354431
layers:44, std:0.5743744373321533
layers:45, std:0.5791332721710205
layers:46, std:0.5789337158203125
layers:47, std:0.5805914402008057
layers:48, std:0.5796007513999939
layers:49, std:0.5833531022071838
layers:50, std:0.5896912813186646
layers:51, std:0.5851364731788635
layers:52, std:0.5816906094551086
layers:53, std:0.5805508494377136
layers:54, std:0.5876169204711914
layers:55, std:0.576688826084137
layers:56, std:0.5784814357757568
layers:57, std:0.5820549726486206
layers:58, std:0.5837342739105225
layers:59, std:0.5691872835159302
layers:60, std:0.5777156949043274
layers:61, std:0.5763663649559021
layers:62, std:0.5843147039413452
layers:63, std:0.5852570533752441
layers:64, std:0.5836994051933289
layers:65, std:0.5794276595115662
layers:66, std:0.590632438659668
layers:67, std:0.5765355825424194
layers:68, std:0.5794717073440552
layers:69, std:0.5696660876274109
layers:70, std:0.5910594463348389
layers:71, std:0.5822493433952332
layers:72, std:0.5893915295600891
layers:73, std:0.5875967741012573
layers:74, std:0.5845006108283997
layers:75, std:0.573967695236206
layers:76, std:0.5823272466659546
layers:77, std:0.5769740343093872
layers:78, std:0.5787169933319092
layers:79, std:0.5757712721824646
layers:80, std:0.5799717307090759
layers:81, std:0.577584981918335
layers:82, std:0.581005334854126
layers:83, std:0.5819255113601685
layers:84, std:0.577966570854187
layers:85, std:0.5941665172576904
layers:86, std:0.5822250247001648
layers:87, std:0.5828983187675476
layers:88, std:0.5758668184280396
layers:89, std:0.5786070823669434
layers:90, std:0.5724494457244873
layers:91, std:0.5775058269500732
layers:92, std:0.5749661326408386
layers:93, std:0.5795350670814514
layers:94, std:0.5690663456916809
layers:95, std:0.5838885307312012
layers:96, std:0.578350305557251
layers:97, std:0.5750819444656372
layers:98, std:0.5843801498413086
layers:99, std:0.5825926065444946
tensor([[1.0858, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.7552],
        [0.0000, 0.0000, 0.3216,  ..., 0.0000, 0.0000, 0.1931],
        [0.0000, 0.5979, 0.1423,  ..., 1.2776, 0.9048, 0.0000],
        ...,
        [0.8705, 0.4248, 0.0000,  ..., 0.0000, 0.8963, 0.3446],
        [0.0000, 0.0000, 0.0000,  ..., 0.5631, 0.0000, 0.4281],
        [1.1301, 0.0000, 0.0000,  ..., 2.2642, 0.3234, 0.0000]],
       grad_fn=<ReluBackward0>)

The effect is even better than with kaiming_normal initialization. This example shows that with BN layers we can skip careful weight initialization and still avoid the ICS problem.

III. An Application Example of BN

We use LeNet on an RMB banknote binary classification task (distinguishing 1-yuan from 100-yuan notes):

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt

import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from model.lenet import LeNet, LeNet_bn
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed


class LeNet_bn(nn.Module):
    def __init__(self, classes):
        super(LeNet_bn, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.bn1 = nn.BatchNorm2d(num_features=6)

        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn2 = nn.BatchNorm2d(num_features=16)

        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.bn3 = nn.BatchNorm1d(num_features=120)

        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = F.max_pool2d(out, 2)

        out = self.conv2(out)
        out = self.bn2(out)
        out = F.relu(out)

        out = F.max_pool2d(out, 2)

        out = out.view(out.size(0), -1)

        out = self.fc1(out)
        out = self.bn3(out)
        out = F.relu(out)

        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

    def initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight.data, 0, 1)
                m.bias.data.zero_()


set_seed(1)  # set the random seed
rmb_label = {"1": 0, "100": 1}

# hyperparameter settings
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 data ============================
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rmb_split"))
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

if not os.path.exists(split_dir):
    raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.8),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build the Dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# build the DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================

# net = LeNet_bn(classes=2)    # swap in to run the BN version of LeNet
net = LeNet(classes=2)
# net.initialize_weights()     # uncomment to apply the custom weight initialization

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()                                                   # choose the loss function

# ============================ step 4/5 optimizer ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # choose the optimizer
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # learning rate decay schedule

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

iter_count = 0
# build the SummaryWriter
writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        iter_count += 1

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # print training information
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

        # log data to the event file
        writer.add_scalars("Loss", {"Train": loss.item()}, iter_count)
        writer.add_scalars("Accuracy", {"Train": correct / total}, iter_count)

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            loss_val_mean = loss_val / len(valid_loader)
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))

            # log data to the event file
            writer.add_scalars("Loss", {"Valid": loss_val_mean}, iter_count)
            writer.add_scalars("Accuracy", {"Valid": correct_val / total_val}, iter_count)

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval  # valid_curve records one loss per epoch, so convert the record points to iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

First, the result of plain LeNet with neither BN layers nor special initialization:
[Figure: training/validation loss curves for plain LeNet]
In the later stages, badly scaled data cause large oscillations in the training loss.

Next, the result after adding the weight initialization:
[Figure: training/validation loss curves for LeNet with initialized weights]
The later stages no longer oscillate, but the early stages are still not ideal.

Finally, the result with LeNet_bn, the version of LeNet with BN layers:
[Figure: training/validation loss curves for LeNet_bn]
Relatively speaking this is the most satisfactory loss curve: it still oscillates, but with a much smaller amplitude.

IV. Implementation of BN in PyTorch

1. The _BatchNorm class

class _BatchNorm(_NormBase):

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_BatchNorm, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked = self.num_batches_tracked + 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

This is the base class from which the BN layers inherit. From its __init__ we can see the main parameters:

  1. num_features: the number of features per sample.
  2. eps: a term added to the denominator for numerical stability, so the standardization never divides by a zero variance.
  3. momentum: the coefficient of the exponential moving average used to estimate the running mean/var.
  4. affine: whether to apply the affine transform.
  5. track_running_stats: whether running statistics are tracked, which is what separates training from evaluation behaviour. In training mode the mean/var come from the current batch and are folded into an exponentially weighted running average, so they differ from batch to batch; in evaluation mode the accumulated running mean/var are used and stay fixed (see the short sketch after this list).
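
A minimal sketch of that training/evaluation difference (assuming the default track_running_stats=True):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(3)
x = torch.randn(8, 3)

bn.train()
y_train = bn(x)            # normalized with this batch's statistics; running stats get updated
print(bn.running_mean)     # no longer the initial zeros

bn.eval()
y_eval = bn(x)             # normalized with the fixed running_mean / running_var instead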

In addition, the forward function shows that the standardization and the affine transform are carried out by the final return statement, which calls PyTorch's functional API. The main attributes passed into that call are:

  1. running_mean: the running mean.
  2. running_var: the running variance.
  3. weight: gamma in the affine transform.
  4. bias: beta in the affine transform.

These four attributes correspond to the quantities in the normalization formula applied by the layer:

y = (x - mean) / sqrt(var + eps) * weight + bias

where mean and var are the statistics of the current batch during training, and the stored running_mean/running_var at evaluation time.
During training, the running mean and variance are maintained with an exponential moving average:

running_mean = (1 - momentum) * pre_running_mean + momentum * mean_t

Here pre_running_mean is the running mean after the previous batch and mean_t is the mean of the current batch; the final running mean is this exponentially weighted combination, and the running variance is updated in the same way.
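
A small check of the evaluation-time behaviour (a standalone sketch: in eval mode the layer applies exactly the formula above with the stored running statistics):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn(torch.randn(16, 4))     # one training-mode pass so the running stats are non-trivial
bn.eval()

x = torch.randn(2, 4)
manual = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) * bn.weight + bn.bias
print(torch.allclose(bn(x), manual))   # True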

2. nn.BatchNorm1d/2d/3d

(1) nn.BatchNorm1d

Like other layers, BN comes in three dimensional variants. Taking the one-dimensional one as an example:

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

Its parameters are exactly those of the base class _BatchNorm.

The standardization formula was given above, but BN does not simply pool the whole batch into one mean and one variance. It works feature by feature (along the feature dimension). The one-dimensional case is shown below:
[Figure: per-feature computation of BatchNorm1d on a 3×5×1 input]
Each column is one data sample; there are three of them, and each sample has 5 features (similar to RGB channels), each of dimension (1) in the one-dimensional case, so the data have shape 3×5×1. What BN does is standardize and affine-transform the first feature across the three samples, i.e., it computes a mean and a variance over the row of three 1s; the other four rows are handled the same way. Because the mean and variance are computed separately for every feature, this is called a per-feature computation.

The PyTorch code:

import torch
import numpy as np
import torch.nn as nn
import sys, os
from tools.common_tools import set_seed

set_seed(1)  # set the random seed

# ======================================== nn.BatchNorm1d
flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 5
    momentum = 0.3

    features_shape = (1)
    
    # the next three lines manually build the B×C×feature_shape = 3×5×1 data described above
    feature_map = torch.ones(features_shape)                                                    # 1D
    feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)         # 2D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)             # 3D
    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm1d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps_bs)

        print("\niteration:{}, running mean: {} ".format(i, bn.running_mean))
        print("iteration:{}, running var:{} ".format(i, bn.running_var))
		
        # manual verification of the running statistics for the second feature
        mean_t, var_t = 2, 0

        running_mean = (1 - momentum) * running_mean + momentum * mean_t
        running_var = (1 - momentum) * running_var + momentum * var_t

        print("iteration:{}, 第二个特征的running mean: {} ".format(i, running_mean))
        print("iteration:{}, 第二个特征的running var:{}".format(i, running_var))

The result:

input data:
tensor([[[1.],
         [2.],
         [3.],
         [4.],
         [5.]],

        [[1.],
         [2.],
         [3.],
         [4.],
         [5.]],

        [[1.],
         [2.],
         [3.],
         [4.],
         [5.]]]) shape is torch.Size([3, 5, 1])

iteration:0, running mean: tensor([0.3000, 0.6000, 0.9000, 1.2000, 1.5000]) 
iteration:0, running var:tensor([0.7000, 0.7000, 0.7000, 0.7000, 0.7000]) 
iteration:0, running mean of the 2nd feature: 0.6 
iteration:0, running var of the 2nd feature: 0.7

iteration:1, running mean: tensor([0.5100, 1.0200, 1.5300, 2.0400, 2.5500]) 
iteration:1, running var:tensor([0.4900, 0.4900, 0.4900, 0.4900, 0.4900]) 
iteration:1, running mean of the 2nd feature: 1.02 
iteration:1, running var of the 2nd feature: 0.48999999999999994

At the start the running mean and variance are 0 and 1, and momentum is 0.3. The current batch mean of the first feature is 1 and its variance is 0, so you can check by hand: running_mean of the first feature = 0.7 × 0 + 0.3 × 1 = 0.3; running_var of the first feature = 0.7 × 1 + 0.3 × 0 = 0.7. The remaining values are computed the same way; working a couple out by hand helps the idea sink in.
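
One caveat this example hides (because the per-feature variance is exactly 0 here): when updating running_var, PyTorch uses the unbiased sample variance of the batch, while the normalization itself uses the biased variance. A quick standalone check:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(1, momentum=1.0)    # momentum=1.0: running stats are fully replaced by the batch stats
x = torch.randn(8, 1)
bn(x)

print(bn.running_var.item())            # matches the unbiased estimate...
print(x.var(unbiased=True).item())      # ...i.e. the Bessel-corrected variance of the batch
print(x.var(unbiased=False).item())     # the biased variance, used for the normalization itself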

(2) nn.BatchNorm2d

Now the two-dimensional case. It is just like the one-dimensional one, except that each feature is itself two-dimensional:
[Figure: per-feature computation of BatchNorm2d on a 3×3×2×2 input]
The horizontal axis is still the sample index and the vertical axis the feature index. The mean and variance of the first feature are again computed across the three samples, and the second and third features likewise, row by row. The input therefore becomes four-dimensional, B×C×H×W, where B is the batch size, C the number of features (channels), and H and W the spatial size of each feature; in the figure this would be 3×3×2×2. In the code below, to keep B and C distinguishable, B is 3 (three samples) and C is 6 (six features):

flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 6
    momentum = 0.3
    
    features_shape = (2, 2)

    feature_map = torch.ones(features_shape)                                                    # 2D
    feature_maps = torch.stack([feature_map*(i+1) for i in range(num_features)], dim=0)         # 3D
    feature_maps_bs = torch.stack([feature_maps for i in range(batch_size)], dim=0)             # 4D

    print("input data:\n{} shape is {}".format(feature_maps_bs, feature_maps_bs.shape))

    bn = nn.BatchNorm2d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps_bs)

        print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))

        print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))

The result:

input data:
tensor([[[[1., 1.],
          [1., 1.]],

         [[2., 2.],
          [2., 2.]],

         [[3., 3.],
          [3., 3.]],

         [[4., 4.],
          [4., 4.]],

         [[5., 5.],
          [5., 5.]],

         [[6., 6.],
          [6., 6.]]],


        [[[1., 1.],
          [1., 1.]],

         [[2., 2.],
          [2., 2.]],

         [[3., 3.],
          [3., 3.]],

         [[4., 4.],
          [4., 4.]],

         [[5., 5.],
          [5., 5.]],

         [[6., 6.],
          [6., 6.]]],


        [[[1., 1.],
          [1., 1.]],

         [[2., 2.],
          [2., 2.]],

         [[3., 3.],
          [3., 3.]],

         [[4., 4.],
          [4., 4.]],

         [[5., 5.],
          [5., 5.]],

         [[6., 6.],
          [6., 6.]]]]) shape is torch.Size([3, 6, 2, 2])

iter:0, running_mean.shape: torch.Size([6])
iter:0, running_var.shape: torch.Size([6])
iter:0, weight.shape: torch.Size([6])
iter:0, bias.shape: torch.Size([6])

iter:1, running_mean.shape: torch.Size([6])
iter:1, running_var.shape: torch.Size([6])
iter:1, weight.shape: torch.Size([6])
iter:1, bias.shape: torch.Size([6])

Note that all four parameters have size 6, because the computation is per feature (per channel).
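
Put differently (a small sketch of what per-feature means here): for a 4D input, each channel's statistics are taken over the batch dimension and both spatial dimensions together, leaving one value per channel:

import torch

x = torch.randn(3, 6, 2, 2)          # B×C×H×W
print(x.mean(dim=(0, 2, 3)).shape)   # torch.Size([6]): one mean per feature/channel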

(3) nn.BatchNorm3d

[Figure: per-feature computation of nn.BatchNorm3d]
Only the shape of each feature changes, becoming three-dimensional; the computation is still per feature. In the implementation below, note that the number of features is now 4:

flag = 1
# flag = 0
if flag:

    batch_size = 3
    num_features = 4
    momentum = 0.3

    features_shape = (2, 2, 3)

    feature = torch.ones(features_shape)                                                # 3D
    feature_map = torch.stack([feature * (i + 1) for i in range(num_features)], dim=0)  # 4D
    feature_maps = torch.stack([feature_map for i in range(batch_size)], dim=0)         # 5D

    print("input data:\n{} shape is {}".format(feature_maps, feature_maps.shape))

    bn = nn.BatchNorm3d(num_features=num_features, momentum=momentum)

    running_mean, running_var = 0, 1

    for i in range(2):
        outputs = bn(feature_maps)

        print("\niter:{}, running_mean.shape: {}".format(i, bn.running_mean.shape))
        print("iter:{}, running_var.shape: {}".format(i, bn.running_var.shape))

        print("iter:{}, weight.shape: {}".format(i, bn.weight.shape))
        print("iter:{}, bias.shape: {}".format(i, bn.bias.shape))

The result:

input data:
tensor([[[[[1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.]]],


         [[[2., 2., 2.],
           [2., 2., 2.]],

          [[2., 2., 2.],
           [2., 2., 2.]]],


         [[[3., 3., 3.],
           [3., 3., 3.]],

          [[3., 3., 3.],
           [3., 3., 3.]]],


         [[[4., 4., 4.],
           [4., 4., 4.]],

          [[4., 4., 4.],
           [4., 4., 4.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.]]],


         [[[2., 2., 2.],
           [2., 2., 2.]],

          [[2., 2., 2.],
           [2., 2., 2.]]],


         [[[3., 3., 3.],
           [3., 3., 3.]],

          [[3., 3., 3.],
           [3., 3., 3.]]],


         [[[4., 4., 4.],
           [4., 4., 4.]],

          [[4., 4., 4.],
           [4., 4., 4.]]]],



        [[[[1., 1., 1.],
           [1., 1., 1.]],

          [[1., 1., 1.],
           [1., 1., 1.]]],


         [[[2., 2., 2.],
           [2., 2., 2.]],

          [[2., 2., 2.],
           [2., 2., 2.]]],


         [[[3., 3., 3.],
           [3., 3., 3.]],

          [[3., 3., 3.],
           [3., 3., 3.]]],


         [[[4., 4., 4.],
           [4., 4., 4.]],

          [[4., 4., 4.],
           [4., 4., 4.]]]]]) shape is torch.Size([3, 4, 2, 2, 3])

iter:0, running_mean.shape: torch.Size([4])
iter:0, running_var.shape: torch.Size([4])
iter:0, weight.shape: torch.Size([4])
iter:0, bias.shape: torch.Size([4])

iter:1, running_mean.shape: torch.Size([4])
iter:1, running_var.shape: torch.Size([4])
iter:1, weight.shape: torch.Size([4])
iter:1, bias.shape: torch.Size([4])
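
Again all four parameters have size 4, one entry per feature. As in the 2D case, the statistics of each channel are taken over the batch and all three spatial dimensions (a quick sketch):

import torch

x = torch.randn(3, 4, 2, 2, 3)          # B×C×D×H×W
print(x.mean(dim=(0, 2, 3, 4)).shape)   # torch.Size([4]): one mean per feature/channel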