
# 调用库API实现以及手写实现 (normalization implemented via the library API and by hand, then compared)
import torch
import torch.nn as nn
# ---- Toy configuration ------------------------------------------------------
batch_szie = 2      # N: number of samples (misspelled name kept; other code may refer to it)
time_steps = 3      # L: sequence length
embedding_dim = 4   # C: channels / features per time step
num_group = 2       # group count -- not used in this section (presumably for a later group-norm part)
inputx = torch.randn(batch_szie, time_steps, embedding_dim)  # layout [N, L, C]

## 1. Batch normalization: library API vs. hand-written version -- statistics are per channel
# NLP layout [N, L, C]    -> one statistic per C
# CV layout  [N, C, H, W] -> one statistic per C
batch_norm_op = torch.nn.BatchNorm1d(embedding_dim, affine=False)
# BatchNorm1d wants channels on dim 1 ([N, C, L]): swap L and C on the way in, swap back on the way out.
bn_y = batch_norm_op(inputx.permute(0, 2, 1)).permute(0, 2, 1)

# Hand-written statistics: reduce over batch (dim 0) and time (dim 1). keepdim=True leaves a
# [1, 1, C] tensor that broadcasts directly against the [N, L, C] input, so no manual repeat is needed.
bn_mean = inputx.mean(keepdim=True, dim=(0, 1))
# Population (biased) std, i.e. the same variance BatchNorm normalizes with in training mode.
bn_std = inputx.std(keepdim=True, dim=(0, 1), unbiased=False)