作用:对给定tensor的指定维度进行排序,返回排序后的结果和排序后的值对应原来维度位置的序号。
举例说明:
import torch
a = torch.randint(2, 10,(6,4)) # 创建shape为6*4,值为[2,10]的随机整数的tensor
b, sort_index = torch.sort(a, dim=1, descending=True) # 对a的第1维度(列)进行降序排序,返回结果和排序后的值对应原来维度位置的序号
print('a:', a)
print('b:', b)
print('sort_index:', sort_index)
''' 运行结果 '''
a: tensor([[8, 5, 7, 8],
[9, 6, 6, 9],
[3, 6, 8, 7],
[2, 8, 9, 9],
[6, 9, 4, 6],
[9, 9, 4, 7]])
b: tensor([[8, 8, 7, 5],
[9, 9, 6, 6],
[8, 7, 6, 3],
[9, 9, 8, 2],
[9, 6, 6, 4],
[9, 9, 7, 4]])
sort_index: tensor([[0, 3, 2, 1],
[0, 3, 1, 2],
[2, 3, 1, 0],
[2, 3, 1, 0],
[1, 0, 3, 2],
[0, 1, 3, 2]])