[PyTorch] (8) Batch Normalization


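A summary of Batch Normalization (BN), one of the most effective tricks for training neural networks.

BN classes in the torch.nn module

PyTorch's torch.nn module provides several BN classes: nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d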

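The main parameters are:

  • num_features: the number of features (the size C of the channel dimension)
  • eps=1e-05: the ε added to the variance in the denominator to prevent division by zero
  • momentum=0.1: the momentum used for the running averages of the mean and variance
  • affine=True: whether to apply a learnable affine transform (scale and shift)
  • track_running_stats=True: whether to keep running averages of the mean and variance.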

In most cases only num_features needs to be set; the defaults are fine for everything else.
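To make the parameters above concrete, here is a minimal sketch that inspects what affine and track_running_stats create on the module and checks the momentum update after one training step. The tensor sizes and the seed are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Defaults: eps=1e-5, momentum=0.1, affine=True, track_running_stats=True
bn = nn.BatchNorm1d(4)

# affine=True adds a learnable per-feature scale (weight) and shift (bias)
print(bn.weight.shape, bn.bias.shape)
# track_running_stats=True registers buffers, initialised to mean 0 and variance 1
print(bn.running_mean, bn.running_var)

x = torch.randn(8, 4) * 3 + 2
bn.train()
bn(x)  # one forward pass in training mode updates the running statistics

# Update rule: running = (1 - momentum) * running + momentum * batch_statistic
# (the running variance is updated with the unbiased batch variance)
expected_mean = (1 - bn.momentum) * torch.zeros(4) + bn.momentum * x.mean(dim=0)
expected_var = (1 - bn.momentum) * torch.ones(4) + bn.momentum * x.var(dim=0, unbiased=True)
mean_ok = torch.allclose(bn.running_mean, expected_mean, atol=1e-6)
var_ok = torch.allclose(bn.running_var, expected_var, atol=1e-6)
print(mean_ok, var_ok)
```

With momentum=0.1, each training batch moves the running statistics one tenth of the way toward that batch's statistics.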

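Looking more closely at track_running_stats, there are two cases:
(1) track_running_stats=True
Training phase, model.train(): BN normalizes with the mean and variance of the current training batch, and updates the running averages of the mean and variance
Test phase, model.eval(): BN normalizes with the running averages of the mean and variance accumulated during training

(2) track_running_stats=False
Training phase, model.train(): BN normalizes with the mean and variance of the current training batch, and does not keep running averages
Test phase, model.eval(): BN normalizes with the mean and variance of the current test batch

Note: incorrect setups such as calling model.eval() during training or model.train() during testing are not considered here.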
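The two cases can be observed directly. The sketch below feeds the same batch through two BatchNorm1d layers in eval mode, one with tracking and one without; the data scale (mean about 5) is an arbitrary choice to make the difference visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(16, 3) * 2 + 5

bn_track = nn.BatchNorm1d(3, track_running_stats=True)
bn_no_track = nn.BatchNorm1d(3, track_running_stats=False)

# Without tracking, the running-statistics buffers are not created at all
print(bn_no_track.running_mean)

# One training step moves the running statistics part of the way to the batch statistics
bn_track.train()
bn_track(x)

bn_track.eval()
bn_no_track.eval()
out_track = bn_track(x)
out_no_track = bn_no_track(x)

# eval + tracking: normalised with the running averages, so the output is not centred
print(out_track.mean(dim=0))
# eval + no tracking: normalised with this batch's own statistics, so the mean is about 0
print(out_no_track.mean(dim=0))
```

After only one training step the running averages are still far from the data statistics, which is why the tracked layer's eval output is visibly off-centre here.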

nn.BatchNorm1d

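Applies batch normalization over a 2D or 3D input (a mini-batch of 1D inputs with an optional additional channel dimension). Suitable for use after fully connected layers.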

Input: (N, C)

Output: (N, C)

import torch
import torch.nn as nn
m = nn.BatchNorm1d(100)
input_1d = torch.randn(64, 100)
output = m(input_1d)
print(output.size())

Output:

torch.Size([64, 100])

Input: (N, C, L)

Output: (N, C, L)

m = nn.BatchNorm1d(100)
input_1d = torch.randn(64, 100,2)
output = m(input_1d)
print(output.size())

Output:

torch.Size([64, 100, 2])
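The two input shapes differ in which dimensions the statistics are taken over. As a rough check, the sketch below recomputes the normalization by hand (with the default weight of 1 and bias of 0) and compares it with the module's output in training mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(100)
bn.train()

# (N, C): one mean/variance per feature, taken over the batch dimension
x2 = torch.randn(64, 100)
manual2 = (x2 - x2.mean(dim=0)) / torch.sqrt(x2.var(dim=0, unbiased=False) + bn.eps)
match_2d = torch.allclose(bn(x2), manual2, atol=1e-4)

# (N, C, L): one mean/variance per channel, taken over the batch and length dimensions
x3 = torch.randn(64, 100, 2)
mean3 = x3.mean(dim=(0, 2), keepdim=True)
var3 = x3.var(dim=(0, 2), unbiased=False, keepdim=True)
manual3 = (x3 - mean3) / torch.sqrt(var3 + bn.eps)
match_3d = torch.allclose(bn(x3), manual3, atol=1e-4)

print(match_2d, match_3d)
```

Note that the normalization itself uses the biased variance (unbiased=False).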

nn.BatchNorm2d

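Applies batch normalization over a 4D input (a mini-batch of 2D inputs with an additional channel dimension). Suitable for use after convolutional layers.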

Input: (N, C, H, W)

Output: (N, C, H, W)

m = nn.BatchNorm2d(3)
input_2d = torch.randn(32, 3, 64, 64)
output = m(input_2d)
print(output.size())

Output:

torch.Size([32, 3, 64, 64])
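For image-shaped input, BN keeps one mean and variance per channel, computed over N, H and W. The following sketch, using arbitrarily scaled random data, checks that each output channel ends up with mean close to 0 and variance close to 1, and that the parameter count depends only on C.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
bn.train()

# Input deliberately far from zero mean / unit variance
x = torch.randn(32, 3, 64, 64) * 4 + 10
y = bn(x)

# Per-channel statistics of the output, taken over N, H and W
channel_mean = y.mean(dim=(0, 2, 3))
channel_var = y.var(dim=(0, 2, 3), unbiased=False)
print(channel_mean)
print(channel_var)

# One scale and one shift per channel, regardless of H and W
n_params = sum(p.numel() for p in bn.parameters())
print(n_params)
```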

nn.BatchNorm3d

Applies batch normalization over a 5D input (a mini-batch of 3D inputs with an additional channel dimension).

Input: (N, C, D, H, W)

Output: (N, C, D, H, W)

m = nn.BatchNorm3d(3)
input_3d = torch.randn(64,3,64,64,100)
output = m(input_3d)
print(output.size())

Output:

torch.Size([64, 3, 64, 64, 100])

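The BN classes above require the number of features to be specified. Newer versions of PyTorch also provide nn.LazyBatchNorm1d, nn.LazyBatchNorm2d and nn.LazyBatchNorm3d, which infer the number of features from input.size(1), so it does not need to be given.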
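A short sketch of the lazy variant, assuming a PyTorch version that ships the lazy modules: the layer is constructed without num_features and fills it in on the first forward pass.

```python
import torch
import torch.nn as nn

lazy_bn = nn.LazyBatchNorm2d()    # no num_features given
x = torch.randn(32, 3, 64, 64)
out = lazy_bn(x)                  # the first forward pass infers num_features from x.size(1)

print(lazy_bn.num_features)
print(lazy_bn.weight.shape)
print(out.size())
```

The parameters only exist after that first forward pass, so the module should see one batch before its parameters are inspected or handed to an optimizer.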

LeNet-5 + BN

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms


class Flatten(nn.Module):
    '''Newer versions of PyTorch provide nn.Flatten directly.'''
    def forward(self, x):
        return x.flatten(1)


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv_bn_act = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.BatchNorm2d(6),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.BatchNorm1d(120),
            nn.ReLU(True),
            nn.Linear(120, 84),
            nn.BatchNorm1d(84),
            nn.ReLU(True),
            nn.Linear(84, 10)
        )

        self.conv_act = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            Flatten(),
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(True),
            nn.Linear(120, 84),
            nn.ReLU(True),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        # Uses the branch with BN; switch to self.conv_act for the baseline without BN
        x = self.conv_bn_act(x)
        return x


def train_loop(dataloader, model, loss_fn, optimizer, device):
    for i, data in enumerate(dataloader, 0):
        # Get the inputs and move them to the device
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Compute predictions and loss
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)

        # Backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i + 1) % 100 == 0:
            print('[Batch%4d] loss: %.3f' % (i + 1, loss.item()))



def test_loop(dataloader, model, device):
    correct = 0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
            100 * correct / total))


if __name__ == '__main__':
    # Device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)
    # Datasets
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  # normalize the image data

    trainset = datasets.CIFAR10(root='./cifar10_data', train=True,
                                download=True, transform=transform)
    # Load data with num_workers subprocesses
    trainloader = DataLoader(trainset, batch_size=64,
                             shuffle=True, num_workers=2)

    testset = datasets.CIFAR10(root='./cifar10_data', train=False,
                               download=True, transform=transform)
    testloader = DataLoader(testset, batch_size=64,
                            shuffle=False, num_workers=2)
    # Hyperparameters
    lr = 0.01  # use a larger learning rate: 0.001 -> 0.01
    epochs = 10
    # Model instance
    model = LeNet5().to(device)
    # Loss function instance
    loss_fn = nn.CrossEntropyLoss()
    # Optimizer instance
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        model.train()
        train_loop(trainloader, model, loss_fn, optimizer, device=device)
        model.eval()
        test_loop(testloader, model, device=device)
    print("Done!")
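The listing above builds both the BN branch and the plain branch and picks one in forward, so the optimizer also receives the parameters of the unused branch. One possible alternative, sketched here with a hypothetical make_lenet5 helper and use_bn flag that are not part of the original code, builds only the layers that are needed. It relies on nn.Flatten being available.

```python
import torch
import torch.nn as nn


def make_lenet5(use_bn: bool) -> nn.Sequential:
    """Build the LeNet-5 variant from the listing; use_bn is a hypothetical switch."""
    def block(layer, bn):
        # Layer, then optional BN, then ReLU
        return [layer] + ([bn] if use_bn else []) + [nn.ReLU(True)]

    layers = (
        block(nn.Conv2d(3, 6, 5), nn.BatchNorm2d(6)) + [nn.MaxPool2d(2, 2)]
        + block(nn.Conv2d(6, 16, 5), nn.BatchNorm2d(16)) + [nn.MaxPool2d(2, 2)]
        + [nn.Flatten()]
        + block(nn.Linear(16 * 5 * 5, 120), nn.BatchNorm1d(120))
        + block(nn.Linear(120, 84), nn.BatchNorm1d(84))
        + [nn.Linear(84, 10)]
    )
    return nn.Sequential(*layers)


x = torch.randn(4, 3, 32, 32)  # a small CIFAR-10 sized batch
counts = {}
shapes = {}
for use_bn in (False, True):
    model = make_lenet5(use_bn)
    counts[use_bn] = sum(p.numel() for p in model.parameters())
    shapes[use_bn] = model(x).shape
    print(use_bn, shapes[use_bn], counts[use_bn])
```

The BN variant adds only a scale and a shift per normalized feature, so the two models are nearly the same size.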





Without BN:

Epoch 1
-------------------------------
[Batch 100] loss: 1.904
[Batch 200] loss: 2.112
[Batch 300] loss: 1.721
[Batch 400] loss: 1.797
[Batch 500] loss: 1.863
[Batch 600] loss: 1.757
[Batch 700] loss: 1.891
Accuracy of the network on the 10000 test images: 33 %
Epoch 2
-------------------------------
[Batch 100] loss: 1.673
[Batch 200] loss: 1.708
[Batch 300] loss: 1.686
[Batch 400] loss: 1.736
[Batch 500] loss: 1.548
[Batch 600] loss: 1.646
[Batch 700] loss: 1.800
Accuracy of the network on the 10000 test images: 36 %
Epoch 3
-------------------------------
[Batch 100] loss: 1.754
[Batch 200] loss: 1.568
[Batch 300] loss: 1.581
[Batch 400] loss: 1.609
[Batch 500] loss: 1.700
[Batch 600] loss: 1.845
[Batch 700] loss: 1.626
Accuracy of the network on the 10000 test images: 39 %
Epoch 4
-------------------------------
[Batch 100] loss: 1.699
[Batch 200] loss: 1.585
[Batch 300] loss: 1.840
[Batch 400] loss: 1.688
[Batch 500] loss: 1.412
[Batch 600] loss: 1.569
[Batch 700] loss: 1.587
Accuracy of the network on the 10000 test images: 42 %
Epoch 5
-------------------------------
[Batch 100] loss: 1.727
[Batch 200] loss: 1.425
[Batch 300] loss: 1.699
[Batch 400] loss: 1.471
[Batch 500] loss: 1.702
[Batch 600] loss: 1.374
[Batch 700] loss: 1.497
Accuracy of the network on the 10000 test images: 41 %
Epoch 6
-------------------------------
[Batch 100] loss: 1.365
[Batch 200] loss: 1.664
[Batch 300] loss: 1.528
[Batch 400] loss: 1.444
[Batch 500] loss: 1.623
[Batch 600] loss: 1.382
[Batch 700] loss: 1.896
Accuracy of the network on the 10000 test images: 44 %
Epoch 7
-------------------------------
[Batch 100] loss: 1.783
[Batch 200] loss: 1.728
[Batch 300] loss: 1.500
[Batch 400] loss: 1.522
[Batch 500] loss: 1.400
[Batch 600] loss: 1.552
[Batch 700] loss: 1.482
Accuracy of the network on the 10000 test images: 44 %
Epoch 8
-------------------------------
[Batch 100] loss: 1.572
[Batch 200] loss: 1.088
[Batch 300] loss: 1.555
[Batch 400] loss: 1.380
[Batch 500] loss: 1.774
[Batch 600] loss: 1.589
[Batch 700] loss: 1.500
Accuracy of the network on the 10000 test images: 45 %
Epoch 9
-------------------------------
[Batch 100] loss: 1.411
[Batch 200] loss: 1.696
[Batch 300] loss: 1.494
[Batch 400] loss: 1.454
[Batch 500] loss: 1.401
[Batch 600] loss: 1.552
[Batch 700] loss: 1.766
Accuracy of the network on the 10000 test images: 48 %
Epoch 10
-------------------------------
[Batch 100] loss: 1.431
[Batch 200] loss: 1.309
[Batch 300] loss: 1.555
[Batch 400] loss: 1.436
[Batch 500] loss: 1.485
[Batch 600] loss: 1.440
[Batch 700] loss: 1.373
Accuracy of the network on the 10000 test images: 47 %
Done!

With BN:

Epoch 1
-------------------------------
[Batch 100] loss: 1.571
[Batch 200] loss: 1.588
[Batch 300] loss: 1.443
[Batch 400] loss: 1.439
[Batch 500] loss: 1.209
[Batch 600] loss: 1.205
[Batch 700] loss: 0.996
Accuracy of the network on the 10000 test images: 55 %
Epoch 2
-------------------------------
[Batch 100] loss: 1.134
[Batch 200] loss: 1.395
[Batch 300] loss: 1.279
[Batch 400] loss: 1.043
[Batch 500] loss: 1.000
[Batch 600] loss: 1.141
[Batch 700] loss: 1.191
Accuracy of the network on the 10000 test images: 59 %
Epoch 3
-------------------------------
[Batch 100] loss: 1.456
[Batch 200] loss: 0.928
[Batch 300] loss: 0.987
[Batch 400] loss: 1.119
[Batch 500] loss: 1.186
[Batch 600] loss: 1.055
[Batch 700] loss: 0.952
Accuracy of the network on the 10000 test images: 62 %
Epoch 4
-------------------------------
[Batch 100] loss: 0.956
[Batch 200] loss: 0.979
[Batch 300] loss: 0.830
[Batch 400] loss: 1.061
[Batch 500] loss: 0.885
[Batch 600] loss: 0.904
[Batch 700] loss: 0.807
Accuracy of the network on the 10000 test images: 61 %
Epoch 5
-------------------------------
[Batch 100] loss: 0.843
[Batch 200] loss: 0.854
[Batch 300] loss: 0.993
[Batch 400] loss: 1.025
[Batch 500] loss: 0.898
[Batch 600] loss: 1.075
[Batch 700] loss: 0.654
Accuracy of the network on the 10000 test images: 63 %
Epoch 6
-------------------------------
[Batch 100] loss: 0.623
[Batch 200] loss: 0.704
[Batch 300] loss: 0.821
[Batch 400] loss: 1.147
[Batch 500] loss: 0.761
[Batch 600] loss: 1.032
[Batch 700] loss: 0.852
Accuracy of the network on the 10000 test images: 64 %
Epoch 7
-------------------------------
[Batch 100] loss: 0.718
[Batch 200] loss: 0.882
[Batch 300] loss: 0.855
[Batch 400] loss: 0.818
[Batch 500] loss: 0.888
[Batch 600] loss: 0.576
[Batch 700] loss: 0.963
Accuracy of the network on the 10000 test images: 65 %
Epoch 8
-------------------------------
[Batch 100] loss: 0.706
[Batch 200] loss: 0.515
[Batch 300] loss: 0.742
[Batch 400] loss: 0.491
[Batch 500] loss: 0.714
[Batch 600] loss: 0.878
[Batch 700] loss: 0.821
Accuracy of the network on the 10000 test images: 66 %
Epoch 9
-------------------------------
[Batch 100] loss: 0.814
[Batch 200] loss: 0.968
[Batch 300] loss: 0.729
[Batch 400] loss: 0.838
[Batch 500] loss: 0.649
[Batch 600] loss: 0.664
[Batch 700] loss: 0.692
Accuracy of the network on the 10000 test images: 67 %
Epoch 10
-------------------------------
[Batch 100] loss: 0.792
[Batch 200] loss: 0.560
[Batch 300] loss: 0.698
[Batch 400] loss: 0.857
[Batch 500] loss: 0.815
[Batch 600] loss: 0.853
[Batch 700] loss: 0.724
Accuracy of the network on the 10000 test images: 66 %
Done!