"""Batch normalization in fully-connected layers.

Demonstrates that a manual batch-norm computation matches
``nn.BatchNorm1d`` in both training mode (batch statistics) and
evaluation mode (running statistics).
"""
import torch
import torch.nn as nn
import copy


class Net(nn.Module):
    """Minimal module wrapping a single BatchNorm1d over `dim` features."""

    def __init__(self, dim, pretrained):
        super(Net, self).__init__()
        # BUGFIX: the original passed 1 as the second positional argument,
        # which set eps=1 instead of the default 1e-5 and made the layer
        # disagree with the manual computations below (which use 1e-5).
        self.bn = nn.BatchNorm1d(dim)
        if pretrained:
            self.pretrained()

    def forward(self, input):
        return self.bn(input)

    def pretrained(self):
        # Reset affine parameters to the identity transform (gamma=1, beta=0).
        nn.init.constant_(self.bn.weight, 1)
        nn.init.constant_(self.bn.bias, 0)


def bn_train_fc(input, model):
    """Manually replicate BatchNorm1d in *training* mode.

    Normalizes `input` (shape [batch, dim]) with the batch statistics, then
    applies the affine parameters read from `model`'s state dict.
    """
    state_dict = model.state_dict()
    print(state_dict)
    weights = torch.tensor([])
    bias = torch.tensor([])
    for k, v in state_dict.items():
        if k == 'bn.weight':
            weights = v
        if k == 'bn.bias':
            bias = v
    weights = weights.view(1, -1).expand(input.shape)
    bias = bias.view(1, -1).expand(input.shape)
    run_mean = torch.mean(input, dim=0).view(1, -1).expand(input.shape)
    # BUGFIX: BatchNorm normalizes with the *biased* variance (dividing by N),
    # while torch.var defaults to the unbiased estimator (N-1); pass
    # unbiased=False so the result matches nn.BatchNorm1d in train mode.
    run_var = torch.var(input, dim=0, unbiased=False).view(1, -1).expand(input.shape)
    output = (input - run_mean) / (run_var + 1e-5).sqrt()
    output = output * weights + bias
    print(run_mean, run_var)
    return output


def bn_test_fc(input, model):
    """Manually replicate BatchNorm1d in *evaluation* mode.

    Uses the running mean/variance stored in `model`'s state dict instead
    of the batch statistics.
    """
    state_dict = model.state_dict()
    weights = torch.tensor([])
    bias = torch.tensor([])
    mean = torch.tensor([])
    # BUGFIX: the original initialized `val` here but assigned/read `var`
    # below — a NameError if 'bn.running_var' is missing from the state dict.
    var = torch.tensor([])
    for k, v in state_dict.items():
        if k == 'bn.weight':
            weights = v
        if k == 'bn.bias':
            bias = v
        if k == 'bn.running_mean':
            mean = v
        if k == 'bn.running_var':
            var = v
    weights = weights.view(1, -1).expand(input.shape)
    bias = bias.view(1, -1).expand(input.shape)
    run_mean = mean.view(1, -1).expand(input.shape)
    run_var = var.view(1, -1).expand(input.shape)
    output = (input - run_mean) / (run_var + 1e-5).sqrt()
    output = output * weights + bias
    return output


if __name__ == '__main__':
    model = Net(dim=5, pretrained=False)
    input = torch.randn((3, 5))
    output = model(input)
    print(model.state_dict()['bn.running_mean'])
    print(model.state_dict()['bn.running_var'])
    # A fully-connected layer's input is [batch_size, num_dims].
    # Batch normalization operates per node (feature): the batch_size values
    # of each node are reduced to a mean and variance. Each FC node plays the
    # role of one CNN channel, so FC BN reduces over the batch dimension only,
    # while CNN BN reduces over batch, H and W per channel.
    output2 = bn_train_fc(input, model)
    print(output, output2)
# BN in fully-connected layers (pseudocode)
# Latest recommended article published 2023-12-17 11:41:32.