import torch
l = torch.rand((1,3,5,5))*5
a, idx1 = torch.sort(l,-1)
b, idx2 = torch.sort(idx1,-1)
print(l)
print(a,idx1)
print(b,idx2)
print(l.scatter(-1,idx1,a))
torch.scatter——一个可以与torch.topk连用的赋值函数
最新推荐文章于 2022-07-25 11:34:04 发布