最近在修改模型时想要实现:根据一个张量上不同元素的值符合一定条件而选择不同的值
A[A>0]=B[A>0]
#对A上大于0的元素替换为B上对应位置元素
于是使用布尔索引的方式进行赋值,发现出现如下报错:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
是因为布尔索引是属于在位操作,在模型反向传播时无法求得梯度。
参考pytorch学习经验(六)torch.where():根据条件修改张量值 - 百度文库
可以使用torch.where()来解决无法反向传播的问题:
A=torch.where(A>0,B,A)
#对A上大于0的元素替换为B上对应元素,其余不变