用法跟上面torch.sort()函数一样,不同的是torch.argsort()返回只是排序后的值所对应原输入input的下标,即torch.sort()返回的indices
dim = 1 表示对每行中的元素进行降序排序,descending=True表示降序排序,输出结果为返回排序后的值所对应原输入input的下标indices
x = torch.randn(3, 4)
indices = torch.argsort(x,dim=1,descending=True)
x,indices
输出结果如下:
(tensor([[-0.6069, -0.9252, -0.9177, 0.6997],
[ 0.3245, -0.0665, 0.4600, 0.0722],
[-1.0662, 2.2669, -0.1171, -0.9208]]),
tensor([[3, 0, 2, 1],
[2, 0, 3, 1],
[1, 2, 3, 0]]))