神经网络中的常见标准化(Normalization)

本文主要介绍常见的标准化方法,并通过代码自己去复现官方的API。常见的几种标准化的方法示意图如下图所示,但是这种比较抽象,初学者很难理解。其实文本和图像都可以使用这些标准化方法,如果把图像中的每个像素想象成文本中一个个字符,然后将图像展平拉伸为1维,那么图像就是文本了。如果实在没法理解,不用担心,本文会提供更加生动的图示。
在这里插入图片描述

1. Batch Normalization

Ioffe S, Szegedy C. Batch normalization: Accelerating deep network training by reducing internal covariate shift[C]//International conference on machine learning. PMLR, 2015: 448-456.

在这里插入图片描述
具体做法如上图所示,计算除了通道维的均值和方差从而进行标准化,即batch中每个样本对应位置的特征做标准化(以均值为例,每个样本第n个通道的特征图计算均值,得到一个均值特征图,再对高和宽再算均值最终得到n个1x1的均值特征图)。其中 γ \gamma γ β \beta β是可学习的超参数。BN的提出是为了加快神经网络的训练并且解决Internal Covariate Shift的问题。使用了BN之后我们可以采用较大的学习率。
在这里插入图片描述

此外,BN也可以提供正则化的作用,从而减少Dropout的使用。

pytorch简洁代码实现:

import torch
import torch.nn as nn


class MyBatchNorm2D(nn.Module):

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(MyBatchNorm2D, self).__init__()
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.eps = eps
        self.momentum = momentum
        self.register_buffer('running_mean', torch.zeros(num_features))  # 保存均值
        self.register_buffer('running_var', torch.ones(num_features))  # 保存方差

    def forward(self, x):
        # x: (N, C, H, W)
        if self.training:
            mean = x.mean(dim=(0, 2, 3), keepdim=True)  # (1, C, 1, 1) 计算每个通道的均值
            var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)  # (1, C, 1, 1) 计算每个通道的方差
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.squeeze()  # [C]
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.squeeze()  # [C]
        else:
            mean = self.running_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)  # (1, C, 1, 1)
            var = self.running_var.unsqueeze(0).unsqueeze(2).unsqueeze(3)  # (1, C, 1, 1)
        x = (x - mean) / (var + self.eps).sqrt()
        x = x * self.gamma.unsqueeze(0).unsqueeze(2).unsqueeze(3) + self.beta.unsqueeze(0).unsqueeze(2).unsqueeze(3)
        return x

2. Layer Normalization

Ba J L, Kiros J R, Hinton G E. Layer normalization[J]. arXiv preprint arXiv:1607.06450, 2016.

batch size过小的情况下,BN的效果往往不那么理想,并且很难应用于NLP的任务,因为NLP的句子末尾通常会有填充的空白token,因此batch中每个样本在末尾相对应的特征做BN完全没有意义。为了解决这一问题,LN被提出,原理很简单,就是每个样本自己做标准化即可。
在这里插入图片描述
在这里插入图片描述
pytorch简洁代码实现:

class MyLN(nn.Module):

    def __init__(self, normalized_shape):
        super(MyLN, self).__init__()
        self.normalized_shape = normalized_shape
        self.scale = nn.Parameter(torch.ones(normalized_shape))
        self.shift = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = 1e-5

    def forward(self, x):
        if isinstance(self.normalized_shape, list):
            dim = [-(i+1) for i in range(len(self.normalized_shape))]
        else:
            dim = -1
        mean = torch.mean(x, dim=dim, keepdim=True)
        var = torch.mean(x**2, dim=dim, keepdim=True) - mean**2
        x = (x - mean) / (torch.sqrt(var + self.eps))
        x = x * self.scale + self.shift
        return x

3. Instance Normalization

Instance Normalization主要用于图像风格迁移的任务中。它对每个Batch中的数据的每个通道进行单独的标准化。
在这里插入图片描述
pytorch简洁代码实现:

import torch
import torch.nn as nn

class MyInstanceNorm(nn.Module):

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(MyInstanceNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.gamma = nn.Parameter(torch.Tensor(num_features))
        self.beta = nn.Parameter(torch.Tensor(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, x):
        # x: (N, C, H, W)
        if self.training:
            mean = x.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)  # (N, C, 1, 1)
            var = ((x - mean) ** 2).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)  # (N, C, 1, 1)
            x = (x - mean) / torch.sqrt(var + self.eps)
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        else:
            x = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        x = x * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1)
        return x

4. Group Normalization

即对每个batch数据的通道进行分组归一化,当组数等于通道数时,即等价于Instance normalization,当组数为1时,即等价于Layer normalization。
在这里插入图片描述
pytorch简洁代码实现:

import torch
import torch.nn as nn

class MyGroupNorm(nn.Module):

    def __init__(self, num_groups, num_channels, eps=1e-5):
        super(MyGroupNorm, self).__init__()
        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.weight = nn.Parameter(torch.Tensor(num_channels))
        self.bias = nn.Parameter(torch.Tensor(num_channels))
        self.reset_parameters()

    def reset_parameters(self):
        self.weight.data.fill_(1)
        self.bias.data.zero_()

    def forward(self, x):
        # x: (N, C, H, W)
        N, C, H, W = x.size()
        G = self.num_groups
        assert C % G == 0
        x = x.view(N, G, -1)  # (N, G, C//G*H*W)
        mean = x.mean(dim=-1, keepdim=True)  # (N, G, 1)
        var = x.var(dim=-1, keepdim=True)  # (N, G, 1)
        x = (x - mean) / torch.sqrt(var + self.eps)
        x = x.view(N, C, H, W)
        x = x * self.weight.view(1, C, 1, 1) + self.bias.view(1, C, 1, 1)
        return x
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值