While trying to quantize a network I needed a custom BN layer, and found the following open-source code online:
I made one change: running_mean and running_var are declared as Parameters. If they are assigned as plain tensor attributes, the two statistics are not saved with the model, which is quite inconvenient; on the other hand, plain tensors also do not trigger the error from the title.
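The saving problem can be reproduced in isolation: a tensor stored as a plain attribute never appears in `state_dict()`, while a non-trainable Parameter does (as does a buffer registered via `register_buffer`, PyTorch's usual mechanism for running statistics). A minimal standalone sketch (the class name `Demo` is mine):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: not tracked by the Module, missing from state_dict()
        self.plain_mean = torch.zeros(4)
        # Non-trainable Parameter: tracked and saved (the approach in this post)
        self.param_mean = nn.Parameter(torch.zeros(4), requires_grad=False)
        # Buffer: the idiomatic way to persist non-trainable state
        self.register_buffer('buffer_mean', torch.zeros(4))

m = Demo()
keys = m.state_dict().keys()
print('plain_mean' in keys)   # False
print('param_mean' in keys)   # True
print('buffer_mean' in keys)  # True
```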
class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims, w_bit, in_bit, l_shift, out_bit):
        super(BatchNorm, self).__init__()
        # (1, C) for fully connected layers, (1, C, 1, 1) for conv layers
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # running statistics as non-trainable Parameters so they are saved
        self.running_mean = nn.Parameter(torch.zeros(shape), requires_grad=False)
        self.running_var = nn.Parameter(torch.ones(shape), requires_grad=False)
        self.eps = 1e-5
        self.w_bit = w_bit
        self.in_bit = in_bit
        self.l_shift = l_shift
        self.out_bit = out_bit
    def forward(self, X):
        if self.running_mean.device != X.device:
            # .to() returns a plain Tensor, which cannot be assigned back to a
            # Parameter attribute, so move the underlying data instead
            self.running_mean.data = self.running_mean.data.to(X.device)
            self.running_var.data = self.running_var.data.to(X.device)
        Y, running_mean_out, running_var_out = batch_norm(
            self.w_bit, self.in_bit, self.l_shift, self.out_bit,
            self.training, X, self.gamma, self.beta,
            self.running_mean, self.running_var, eps=1e-5, momentum=0.9)
        return Y
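The author's `batch_norm()` itself is not shown here. For reference, a float-only sketch of what such a function typically computes (modeled on the common d2l-style implementation; the quantization arguments `w_bit`, `in_bit`, `l_shift`, `out_bit` are omitted, and `batch_norm_ref` is my own name). Note that the updated statistics it returns are plain tensors, which is exactly what later collides with the Parameter attributes:

```python
import torch

def batch_norm_ref(is_training, X, gamma, beta, moving_mean, moving_var,
                   eps=1e-5, momentum=0.9):
    if not is_training:
        # Inference: normalize with the stored running statistics
        X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        if X.dim() == 2:
            # Fully connected case: statistics over the batch dimension
            mean = X.mean(dim=0)
            var = ((X - mean) ** 2).mean(dim=0)
        else:
            # Convolutional case: statistics over batch and spatial dims
            mean = X.mean(dim=(0, 2, 3), keepdim=True)
            var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
        X_hat = (X - mean) / torch.sqrt(var + eps)
        # These updates produce plain tensors, even if moving_mean /
        # moving_var came in as Parameters
        moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
        moving_var = momentum * moving_var + (1.0 - momentum) * var
    Y = gamma * X_hat + beta
    return Y, moving_mean, moving_var
```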
During training, the error from the title is raised. The root cause is this line:
Y, running_mean_out, running_var_out = batch_norm(
    self.w_bit, self.in_bit, self.l_shift, self.out_bit,
    self.training, X, self.gamma, self.beta,
    self.running_mean, self.running_var, eps=1e-5, momentum=0.9)
As the error message says, the running_mean returned by batch_norm() is a torch.cuda.FloatTensor, which cannot be assigned directly to an attribute registered as a Parameter. Converting it back to a Parameter at the assignment site removes the error. Training has run fine for me with this change; if I've missed something, corrections are welcome!
Y, running_mean_out, running_var_out = batch_norm(
    self.w_bit, self.in_bit, self.l_shift, self.out_bit,
    self.training, X, self.gamma, self.beta,
    self.running_mean, self.running_var, eps=1e-5, momentum=0.9)
self.running_mean = nn.Parameter(running_mean_out, requires_grad=False)
self.running_var = nn.Parameter(running_var_out, requires_grad=False)
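The error itself can be reproduced in a few lines: once an attribute has been registered as a Parameter, `Module.__setattr__` rejects any plain tensor assigned to that name. A minimal sketch of the failure and two ways out (the class name `M` is mine; fix 1 is the wrap-in-Parameter approach used above):

```python
import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.running_mean = nn.Parameter(torch.zeros(3), requires_grad=False)

m = M()
new_mean = torch.ones(3)  # stands in for what batch_norm() returns

try:
    m.running_mean = new_mean  # plain Tensor assigned to a Parameter slot
except TypeError as e:
    print(type(e).__name__)  # TypeError

# Fix 1: wrap the tensor back into a non-trainable Parameter
m.running_mean = nn.Parameter(new_mean, requires_grad=False)

# Fix 2: mutate in place, keeping the registered Parameter object
m.running_mean.data.copy_(torch.full((3,), 2.0))
```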
A workaround using .cuda() has been suggested online, but when I actually tested it, the parameters could no longer be saved afterwards, so I abandoned that approach.