torch.argmax 本身不可导,梯度不会通过它回传到输入 s(因此 s.grad 为 None);但是用 argmax 得到的索引经 torch.gather 从 d 中选出的那些元素,仍然可以正常接收梯度,未被选中的位置梯度为 0。
import torch
# Demo: argmax is only used to produce indices; the values gathered with
# those indices still take part in the autograd graph.
s = torch.rand(1, 3, 6, 6, requires_grad=True)  # scores, used only to choose indices
d = torch.rand(1, 3, 6, 6, requires_grad=True)  # values that are actually gathered

# Index of the largest score along dim 1, kept as a size-1 dim so it can index d.
winner_idx = s.argmax(dim=1, keepdim=True)
# Pick one element of d along dim 1 per position, then reduce to a scalar.
picked_total = d.gather(1, winner_idx).sum()

# Squared distance of the gathered sum from 1.
residual = picked_total - 1
loss = residual * residual
loss.backward()

print(s.grad)  # integer indices carry no gradient back to s
print(d.grad)  # non-zero only at the gathered positions
output:
None
tensor([[[[ 0.0000, 0.0000, 0.0000, 38.2169, 0.0000, 0.0000],
[38.2169, 0.0000, 0.0000, 38.2169, 38.2169, 0.0000],
[ 0.0000, 38.2169, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 38.2169, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 38.2169, 0.0000, 38.2169, 0.0000, 0.0000],
[38.2169, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],
[[38.2169, 0.0000, 38.2169, 0.0000, 38.2169, 0.0000],
[ 0.0000, 38.2169, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 38.2169, 38.2169, 0.0000, 38.2169],
[38.2169, 0.0000, 0.0000, 38.2169, 38.2169, 38.2169],
[ 0.0000, 0.0000, 38.2169, 0.0000, 0.0000, 38.2169],
[ 0.0000, 0.0000, 38.2169, 0.0000, 0.0000, 0.0000]],
[[ 0.0000, 38.2169, 0.0000, 0.0000, 0.0000, 38.2169],
[ 0.0000, 0.0000, 38.2169, 0.0000, 0.0000, 38.2169],
[38.2169, 0.0000, 0.0000, 0.0000, 38.2169, 0.0000],
[ 0.0000, 0.0000, 38.2169, 0.0000, 0.0000, 0.0000],
[38.2169, 0.0000, 0.0000, 0.0000, 38.2169, 0.0000],
[ 0.0000, 38.2169, 0.0000, 38.2169, 38.2169, 38.2169]]]])