nn.InstanceNorm2d和nn.BatchNorm2d比较

18 篇文章 0 订阅
14 篇文章 0 订阅

nn.InstanceNorm2d和nn.BatchNorm2d比较

介绍

nn.InstanceNorm2dnn.BatchNorm2d 都是 PyTorch 中常用的归一化层,用于提高神经网络的训练稳定性和泛化能力。

主要区别

它们之间的主要区别如下:

  1. 归一化对象:

    • nn.InstanceNorm2d:实例归一化,对每个样本(实例)的特征进行归一化。适用于每个样本的特征分布不同的情况,如图像风格转换等任务。
    • nn.BatchNorm2d:批归一化,对整个批次中的样本的特征进行归一化。适用于训练深度神经网络时,加速训练过程、提高模型的泛化能力。
  2. 归一化方式:

    • nn.InstanceNorm2d:对每个样本的每个通道进行归一化,即对每个特征图的每个位置点进行归一化。
    • nn.BatchNorm2d:对每个通道的特征图进行归一化,即对每个特征图的所有位置点进行归一化。
  3. 归一化参数:

    • nn.InstanceNorm2d:没有可训练的参数,只有归一化的均值和方差。
    • nn.BatchNorm2d:有可训练的参数,包括缩放因子(scale)、偏移量(shift)、归一化的均值和方差。
  4. 使用场景:

    • nn.InstanceNorm2d:适用于图像风格转换、图像生成等需要保持每个样本特征独立性的任务。
    • nn.BatchNorm2d:适用于深度神经网络的训练过程,加速训练、提高模型的泛化能力。

需要根据具体任务和网络结构的特点选择合适的归一化层。在一般情况下,nn.BatchNorm2d 是更常用的归一化层。

计算公式

nn.InstanceNorm2dnn.BatchNorm2d 在计算上的公式如下:

对于 nn.InstanceNorm2d,假设输入为 x ∈ R N × C × H × W x \in \mathbb{R}^{N \times C \times H \times W} xRN×C×H×W,其中 N N N 是批次大小, C C C 是通道数, H H H W W W 是特征图的高度和宽度。实例归一化的计算公式如下:

InstanceNorm2d ( x ) n , c , h , w = x n , c , h , w − μ n , c σ n , c 2 + ϵ ⋅ γ c + β c \text{InstanceNorm2d}(x)_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_{n,c}}{\sqrt{\sigma^2_{n,c} + \epsilon}} \cdot \gamma_c + \beta_c InstanceNorm2d(x)n,c,h,w=σn,c2+ϵ xn,c,h,wμn,cγc+βc

其中:

  • x n , c , h , w x_{n,c,h,w} xn,c,h,w 是输入张量 x x x 在第 n n n 个样本、第 c c c 个通道、第 h h h 行、第 w w w 列的元素。
  • μ n , c \mu_{n,c} μn,c 是第 n n n 个样本、第 c c c 个通道的均值,计算公式为 μ n , c = 1 H × W ∑ h = 1 H ∑ w = 1 W x n , c , h , w \mu_{n,c} = \frac{1}{H \times W} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,c,h,w} μn,c=H×W1h=1Hw=1Wxn,c,h,w
  • σ n , c 2 \sigma^2_{n,c} σn,c2 是第 n n n 个样本、第 c c c 个通道的方差,计算公式为 σ n , c 2 = 1 H × W ∑ h = 1 H ∑ w = 1 W ( x n , c , h , w − μ n , c ) 2 \sigma^2_{n,c} = \frac{1}{H \times W} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{n,c,h,w} - \mu_{n,c})^2 σn,c2=H×W1h=1Hw=1W(xn,c,h,wμn,c)2
  • γ c \gamma_c γc 是归一化的缩放因子(scale),是一个可学习的参数。
  • β c \beta_c βc 是归一化的偏移量(shift),是一个可学习的参数。
  • ϵ \epsilon ϵ 是一个小的常数,用于避免除以零的情况。

对于 nn.BatchNorm2d,假设输入为 x ∈ R N × C × H × W x \in \mathbb{R}^{N \times C \times H \times W} xRN×C×H×W,其中 N N N 是批次大小, C C C 是通道数, H H H W W W 是特征图的高度和宽度。批归一化的计算公式如下:

BatchNorm2d ( x ) n , c , h , w = x n , c , h , w − μ c σ c 2 + ϵ ⋅ γ c + β c \text{BatchNorm2d}(x)_{n,c,h,w} = \frac{x_{n,c,h,w} - \mu_c}{\sqrt{\sigma^2_c + \epsilon}} \cdot \gamma_c + \beta_c BatchNorm2d(x)n,c,h,w=σc2+ϵ xn,c,h,wμcγc+βc

其中:

  • x n , c , h , w x_{n,c,h,w} xn,c,h,w 是输入张量 x x x 在第 n n n 个样本、第 c c c 个通道、第 h h h 行、第 w w w 列的元素。
  • μ c \mu_c μc 是第 c c c 个通道的均值,计算公式为 μ c = 1 N × H × W ∑ n = 1 N ∑ h = 1 H ∑ w = 1 W x n , c , h , w \mu_c = \frac{1}{N \times H \times W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{n,c,h,w} μc=N×H×W1n=1Nh=1Hw=1Wxn,c,h,w
  • σ c 2 \sigma^2_c σc2 是第 c c c 个通道的方差,计算公式为 σ c 2 = 1 N × H × W ∑ n = 1 N ∑ h = 1 H ∑ w = 1 W ( x n , c , h , w − μ c ) 2 \sigma^2_c = \frac{1}{N \times H \times W} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} (x_{n,c,h,w} - \mu_c)^2 σc2=N×H×W1n=1Nh=1Hw=1W(xn,c,h,wμc)2
  • γ c \gamma_c γc 是归一化的缩放因子(scale),是一个可学习的参数。
  • β c \beta_c βc 是归一化的偏移量(shift),是一个可学习的参数。
  • ϵ \epsilon ϵ 是一个小的常数,用于避免除以零的情况。
  • 7
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在使用 SNGAN 的时候,使用 InstanceNorm 代替 BatchNorm 确实可以得到更好的效果。因为 BatchNorm 的计算是在一个 batch 内进行的,而 InstanceNorm 是在一个单独的样本内进行的,因此 InstanceNorm 更适合用于生成器和判别器中,因为它们的输入不同。 因此,如果你将 `SNGANGenerator` 和 `SNGANDiscriminator` 中的 `nn.BatchNorm2d` 替换为 `nn.InstanceNorm2d`,会得到更好的结果。以下是代码示例: ```python import torch import torch.nn as nn class SNGANGenerator(nn.Module): def __init__(self, z_dim=100, image_size=64, num_channels=3, num_filters=64): super(SNGANGenerator, self).__init__() self.image_size = image_size self.num_channels = num_channels self.num_filters = num_filters self.z_dim = z_dim self.linear = nn.Linear(z_dim, self.num_filters * 8 * self.image_size // 8 * self.image_size // 8) self.blocks = nn.Sequential( nn.InstanceNorm2d(self.num_filters * 8), nn.Upsample(scale_factor=2), nn.Conv2d(self.num_filters * 8, self.num_filters * 4, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(self.num_filters * 4), nn.ReLU(inplace=True), nn.Upsample(scale_factor=2), nn.Conv2d(self.num_filters * 4, self.num_filters * 2, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(self.num_filters * 2), nn.ReLU(inplace=True), nn.Upsample(scale_factor=2), nn.Conv2d(self.num_filters * 2, self.num_filters, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(self.num_filters), nn.ReLU(inplace=True), nn.Upsample(scale_factor=2), nn.Conv2d(self.num_filters, self.num_channels, kernel_size=3, stride=1, padding=1), nn.Tanh() ) def forward(self, noise): x = self.linear(noise) x = x.view(-1, self.num_filters * 8, self.image_size // 8, self.image_size // 8) x = self.blocks(x) return x class SNGANDiscriminator(nn.Module): def __init__(self, image_size=64, num_channels=3, num_filters=64): super(SNGANDiscriminator, self).__init__() self.image_size = image_size self.num_channels = num_channels self.num_filters = num_filters self.blocks = nn.Sequential( nn.Conv2d(self.num_channels, self.num_filters, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(self.num_filters), nn.ReLU(inplace=True), nn.Conv2d(self.num_filters, self.num_filters * 2, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(self.num_filters * 2), nn.ReLU(inplace=True), nn.Conv2d(self.num_filters * 2, self.num_filters * 4, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(self.num_filters * 4), nn.ReLU(inplace=True), nn.Conv2d(self.num_filters * 4, self.num_filters * 8, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(self.num_filters * 8), nn.ReLU(inplace=True) ) self.linear = nn.Linear(self.num_filters * 8 * self.image_size // 8 * self.image_size // 8, 1) def forward(self, img): x = self.blocks(img) x = x.view(-1, self.num_filters * 8 * self.image_size // 8 * self.image_size // 8) x = self.linear(x) return x ``` 需要注意的是,如果使用 InstanceNorm,需要保证样本的大小是一致的,否则可能会导致效果变差。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值