解决自定义batch_norm训练时报错:TypeError: cannot assign ‘torch.cuda.FloatTensor‘ as parameter ‘running_mean‘

        在尝试量化网络时涉及到了自定义bn层,从网上找了开源代码如下:

        我做了一些改动,将running_mean和running_var设置为了Parameter,如果直接赋值为tensor类型变量的话,是不会保存这两个参数的,会造成很多不便,但同时也不会报题目中的错误。

class BatchNorm(nn.Module):
    def __init__(self,num_features,num_dims,w_bit, in_bit, l_shift, out_bit):
        super(BatchNorm,self).__init__()
        if num_dims == 2:
            shape = (1,num_features)
        else:
            shape = (1,num_features,1,1)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))

        self.running_mean = nn.Parameter(torch.zeros(shape), requires_grad=False)
        self.running_var = nn.Parameter(torch.ones(shape), requires_grad=False)
        self.eps = 1e-5
        self.w_bit = w_bit
        self.in_bit = in_bit
        self.l_shift = l_shift
        self.out_bit = out_bit

    def forward(self,X):
        # print(self.gamma)
        if self.running_mean.device != X.device:
            self.running_mean = self.running_mean.to(X.device)
            self.running_var = self.running_var.to(X.device)
        Y,running_mean_out,running_var_out = batch_norm(self.w_bit, self.in_bit, self.l_shift, self.out_bit, self.training,X,self.gamma,self.beta,self.running_mean,self.running_var,eps=1e-5,momentum=0.9)
        return Y

        在训练的时候会报题目中的错误,报错的根源在于:

Y,running_mean_out,running_var_out = batch_norm(self.w_bit, self.in_bit, self.l_shift, self.out_bit, self.training,X,self.gamma,self.beta,self.running_mean,self.running_var,eps=1e-5,momentum=0.9)

        其原因正如报错中所提到的,从batch_norm()函数中返回的是一个torch.cuda.FloatTensor类型的running_mean,不能直接赋值给Parameter类型,所以我们在此处可以做一下类型转换,即可消除这个报错,我训练下来应该没啥大问题,如果有错请大佬指正!

Y,running_mean_out,running_var_out = batch_norm(self.w_bit, self.in_bit, self.l_shift, self.out_bit, self.training,X,self.gamma,self.beta,self.running_mean,self.running_var,eps=1e-5,momentum=0.9)
self.running_mean = nn.Parameter(running_mean_out,requires_grad=False)
self.running_var = nn.Parameter(running_var_out,requires_grad=False)

        之前网上提到.cuda()解决的办法,但经过实际测试,添加之后无法保存参数,于是放弃了这种操作。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值