pytorch 使用torch.scatter_()之后无法计算梯度

11 篇文章 0 订阅

上一篇文章介绍到当对需要计算梯度的矩阵进行修改之后梯度计算出错,解决办法是少用a+=a或a=a+1这种形式的赋值操作,然后如果非得使用这些赋值方式的话可以在进行赋值操作之前使用clone()。这篇文章遇到一个问题当使用torch.scatter_()函数对矩阵进行修改时发现就算使用clone()还是会出现梯度无法计算的错误,Google之后找到的解决方法是在对变量进行clone()之前先使用Variable()函数,例子如下:

from torch.autograd import Variable

tensor_to_modify = Variable(tensor_to_modify).clone()

tensor_to_modify.scatter_(dim, index, src)

参考:

https://discuss.pytorch.org/t/backpropagating-through-scatter-add/16887

https://discuss.pytorch.org/t/does-scatter--support-autograd/8587

  • 3
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
pytorch部分代码如下:train_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch,model_ema) if use_amp: with torch.cuda.amp.autocast(): # 开启混合精度 loss = torch.nan_to_num(criterion_train(output, targets)) # 计算loss scaler.scale(loss).backward() # 梯度放大 torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD) if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks or _global_forward_hooks or _global_forward_pre_hooks): return forward_call(*input, **kwargs) class LDAMLoss(nn.Module): def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30): super(LDAMLoss, self).__init__() m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list)) m_list = m_list * (max_m / np.max(m_list)) m_list = torch.cuda.FloatTensor(m_list) self.m_list = m_list assert s > 0 self.s = s self.weight = weight def forward(self, x, target): index = torch.zeros_like(x, dtype=torch.uint8) index.scatter_(1, target.data.view(-1, 1), 1) index_float = index.type(torch.cuda.FloatTensor) batch_m = torch.matmul(self.m_list[None, :], index_float.transpose(0,1)) batch_m = batch_m.view((-1, 1)) x_m = x - batch_m output = torch.where(index, x_m, x) return F.cross_entropy(self.s*output, target, weight=self.weight) 报错:Traceback (most recent call last): File "/home/adminis/hpy/ConvNextV2_Demo/train+ca.py", line 279, in <module> train_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch,model_ema) File "/home/adminis/hpy/ConvNextV2_Demo/train+ca.py", line 46, in train loss = torch.nan_to_num(criterion_train(output, targets)) # 计算loss File "/home/adminis/anaconda3/envs/wln/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl return forward_call(*input, **kwargs) File "/home/adminis/hpy/ConvNextV2_Demo/models/utils.py", line 621, in forward index.scatter_(1, target.data.view(-1, 1), 1) IndexError: scatter_(): Expected dtype int64 for index.
最新发布
06-10

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值