import torch
import torch.nn as nn

class GroupNorm(nn.Module):
    def __init__(self, G, dim, eps=1e-5):
        super(GroupNorm, self).__init__()
        self.G = G
        self.eps = eps
        # Learnable per-channel affine parameters, matching nn.GroupNorm's defaults
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        N, C, H, W = x.shape
        # Split the channels into G groups of C // G channels each
        x = x.reshape(N, self.G, C // self.G, H, W)
        # Normalize over each group's channels and all spatial positions
        mean = torch.mean(x, dim=[2, 3, 4], keepdim=True)
        var = torch.var(x, dim=[2, 3, 4], keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        x = x.reshape(N, C, H, W)
        # Apply the per-channel affine transform
        return x * self.gamma.view(1, C, 1, 1) + self.beta.view(1, C, 1, 1)
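# Note two standard special cases: G = 1 normalizes jointly over all channels
# and spatial positions, while G = C puts each channel in its own group,
# which is equivalent to instance normalization.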
if __name__ == '__main__':
    N, C, H, W = 2, 20, 5, 5
    embedding = torch.randn(N, C, H, W)
    print(embedding.shape)  # torch.Size([2, 20, 5, 5])
    group_norm = nn.GroupNorm(2, 20)
    my_group_norm = GroupNorm(2, 20)
    print(group_norm(embedding)[0][0])
    print(my_group_norm(embedding)[0][0])
The output matches the result of calling PyTorch's built-in nn.GroupNorm.
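For a stronger check than eyeballing a single channel slice, the two outputs can also be compared over the full tensor (a minimal sketch; the atol tolerance is an arbitrary choice):

    print(torch.allclose(group_norm(embedding), my_group_norm(embedding), atol=1e-6))  # expected: True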