The FPN part of DETR's segm segmentation head uses GN.
GN was proposed by Kaiming He's team in 2018.
It addresses BN's weakness of performing poorly with small mini-batches.
Group Normalization (GN) is a proposed alternative to BN: it first divides the channels into several groups, then computes the mean and variance within each group to normalize. Since GN's computation is independent of the batch size, it is also very stable for high-resolution images trained with a small batch size.
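Concretely (this formulation is transcribed from the GN paper, not spelled out in the original text), for each element $x_i$ the statistics are computed over the set $\mathcal{S}_i$ of elements belonging to the same sample and the same channel group:

$$
\mu_i = \frac{1}{m}\sum_{k \in \mathcal{S}_i} x_k,\qquad
\sigma_i^2 = \frac{1}{m}\sum_{k \in \mathcal{S}_i} (x_k - \mu_i)^2,\qquad
\hat{x}_i = \frac{x_i - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}},
$$

where $m = |\mathcal{S}_i|$; a learnable per-channel affine transform $y_i = \gamma \hat{x}_i + \beta$ is applied afterwards.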
The figure below compares how the model error rates of BN and GN change as the batch size gets smaller and smaller:
```python
torch.nn.GroupNorm(num_groups, num_channels, eps=1e-05, affine=True, device=None, dtype=None)
```
```python
import torch
from torch import nn

input = torch.randn(20, 6, 10, 10)
# Separate 6 channels into 3 groups
m = nn.GroupNorm(3, 6)
# Separate 6 channels into 6 groups (equivalent with InstanceNorm)
m = nn.GroupNorm(6, 6)
# Put all 6 channels into a single group (equivalent with LayerNorm)
m = nn.GroupNorm(1, 6)
# Activating the module
output = m(input)
```
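As a quick sanity check of the batch-size independence claim above (a minimal sketch, not part of the original text): because GN computes its statistics per sample, a sample gets the same output whether it is normalized alone or inside a larger batch.

```python
import torch
from torch import nn

m = nn.GroupNorm(3, 6)
x = torch.randn(4, 6, 10, 10)

full = m(x)        # normalize the whole batch at once
single = m(x[:1])  # normalize only the first sample
# Per-sample statistics mean the rest of the batch has no influence
assert torch.allclose(full[:1], single)
```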
Manual implementation
The idea is simply to split the channels into groups, normalize each group, and concat the results back together.
```python
import torch
import torch.nn as nn

gn_layer = nn.GroupNorm(num_groups=3, num_channels=6)
input = torch.randn(8, 6, 20, 20)
gn_outputs = gn_layer(input)

# Split the 6 channels into 3 groups of 2 channels each
group_inputs = torch.split(input, 2, dim=1)
result = []
for group_input in group_inputs:
    # Statistics are computed per sample over the group's channels and spatial dims
    mean = torch.mean(group_input, dim=[1, 2, 3], keepdim=True)
    var = torch.var(group_input, dim=[1, 2, 3], keepdim=True, unbiased=False)
    # eps goes inside the sqrt, matching nn.GroupNorm
    output = (group_input - mean) / torch.sqrt(var + gn_layer.eps)
    result.append(output)
result = torch.cat(result, dim=1)

assert torch.allclose(gn_outputs, result)
```
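The split/concat loop above can also be written without a Python loop by reshaping so that each group becomes a single reduction axis. A minimal vectorized sketch (my own, relying on the same default initialization of affine weight=1, bias=0):

```python
import torch
import torch.nn as nn

gn_layer = nn.GroupNorm(num_groups=3, num_channels=6)
x = torch.randn(8, 6, 20, 20)

n, c, h, w = x.shape
g = gn_layer.num_groups
# Collapse each group's channels and spatial positions into one axis
grouped = x.view(n, g, -1)                  # (N, G, C//G * H * W)
mean = grouped.mean(dim=2, keepdim=True)
var = grouped.var(dim=2, keepdim=True, unbiased=False)
normalized = (grouped - mean) / torch.sqrt(var + gn_layer.eps)
result = normalized.view(n, c, h, w)

assert torch.allclose(gn_layer(x), result)
```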