import torch
import torch.nn as nn
class GroupNorm(nn.Module):
    def __init__(self, G, N):
        # G: number of groups; N: batch size (fixed at construction,
        # since BatchNorm2d must be told it has N * G channels)
        super(GroupNorm, self).__init__()
        self.G = G
        # Batch normalization layer; track_running_stats=False makes it
        # always normalize with the current batch statistics, as GroupNorm does
        self.bn = nn.BatchNorm2d(N * self.G, track_running_stats=False)

    def forward(self, x):
        N, C, H, W = x.shape
        # Fold each (sample, group) pair into its own "channel", so BatchNorm2d
        # computes mean/variance over exactly the C // G channels and H * W
        # positions belonging to that group
        x = x.reshape(1, N * self.G, C // self.G, H * W)
        x = self.bn(x)
        return x.reshape(N, C, H, W)
if __name__ == '__main__':
    N, C, H, W = 16, 2, 5, 5
    embedding = torch.randn(N, C, H, W)
    group_norm = nn.GroupNorm(2, C)   # reference: built-in GroupNorm
    my_group_norm = GroupNorm(2, N)   # custom version built on BatchNorm2d
    print(group_norm(embedding))
    print(my_group_norm(embedding))
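    # Sanity check (a minimal sketch): with freshly initialized affine
    # parameters (weight = 1, bias = 0) and the shared default eps = 1e-5,
    # both layers should agree up to floating-point error.
    assert torch.allclose(group_norm(embedding), my_group_norm(embedding),
                          atol=1e-5)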
Implementing a group normalization layer on top of the built-in batch normalization layer: reshaping the input from (N, C, H, W) to (1, N * G, C // G, H * W) turns every (sample, group) pair into a separate channel, so BatchNorm2d computes exactly the per-group mean and variance that GroupNorm requires. Note that the batch size N is baked into the module at construction, so this layer only accepts inputs with exactly N samples.