函数原型1:
torch.argsort(input, dim=- 1, descending=False) → LongTensor
参数:
input:(Tensor)输入张量
dim:(int类型)要排序的维度
descending:(布尔类型),升序还是降序。默认升序。
作用:返回按照指定维度排序后的值对应排序前的下标。
该函数其实是torch.sort()返回的第二个元素,第一个元素是排序后的Tensor。
函数原型2
torch.sort(input, dim=- 1, descending=False, stable=False, *, out=None)
参数:
input ( Tensor ) :输入张量。
dim ( int , optional ) : 要排序的维度
descending ( bool , optional ) : 控制顺序(升序或降序)
stable ( bool , optional ) : 使排序更加稳定,这保证了等价元素的顺序得以保留。
作用:
将输入张量的元素按照给定的维度按值升序排序。返回一个元组。
实例
x = torch.randint(10, size=(4, 3))
print(f'x:{x}')
x1 = torch.argsort(x, dim=-1, descending=True)#降序
print(x1)
values, indices = torch.sort(x, dim=-1, descending=True)
print(values)
print(indices==x1)
输出结果:
x:tensor([[0, 6, 7],
[8, 5, 5],
[3, 4, 9],
[4, 6, 5]])
tensor([[2, 1, 0],
[0, 1, 2],
[2, 1, 0],
[1, 2, 0]])
tensor([[7, 6, 0],
[8, 5, 5],
[9, 4, 3],
[6, 5, 4]])
tensor([[True, True, True],
[True, True, True],
[True, True, True],
[True, True, True]])