定义1
torch.sort(a,dim,descending)
用法1
输入a,在dim维进行排序,descending控制是否降序,默认为False。
输出排序后的值以及对应值在原a中的下标,
示例1
import torch
a = torch.tensor([[10,2,3],[4,6,5],[7,8,9]])
print(a)
>>tensor([[10, 2, 3],
[ 4, 6, 5],
[ 7, 8, 9]])
在dim=0默认升序
torch.sort(a,0)
>>torch.return_types.sort(
values=tensor([[ 4, 2, 3],
[ 7, 6, 5],
[10, 8, 9]]),
indices=tensor([[1, 0, 0],
[2, 1, 1],
[0, 2, 2]]))
在dim=1降序
torch.sort(a,1,descending=True)
>>torch.return_types.sort(
values=tensor([[10, 3, 2],
[ 6, 5, 4],
[ 9, 8, 7]]),
indices=tensor([[0, 2, 1],
[1, 2, 0],
[2, 1, 0]]))
定义2
torch.argsort()
用法2
返回排序后的值所对应原a的下标,即torch.sort()返回的indices
示例2
将输入a在 dim=0降序排列
torch.argsort(a,0,descending=True)
>>tensor([[0, 2, 2],
[2, 1, 1],
[1, 0, 0]])