两次torch.sort求元素在排序后的位置
- 目的:
利用两次sort求原张量中每个元素第几小(大),测试用例如下
import torch
ts = torch.randn((1, 5), dtype=torch.float)
print(ts)
srt1, idx = ts.sort(dim=1)
print(idx)
srt2, rank = idx.sort(dim=1)
print(rank)
得到的结果如下
tensor([[-1.2281, 0.6057, -1.1720, 1.1262, -0.1582]])
tensor([[0, 2, 4, 1, 3]]) #idx
tensor([[0, 3, 1, 4, 2]]) #rank
在第一次排序后,idx给出的是从小到大排列后,元素在ts中的位置,再次排列后,可以根据元素的大小,来判断张量ts每个元素是第几小。
由rank可知,原张量ts第0个元素对应的是0,即第0小(最小),第1个元素对应的是3,则第3小,以此类推。
- 原理
第一次sort后,我们知道[0, 2, 4, 1, 3]指代ts中第几个元素,且下标[0, 1, 2, 3, 4]对应第几小,[0, 2, 4, 1, 3]再经过一次sort后,会得到[0, 1, 2, 3, 4],相当于恢复原始张量ts的排布,而此时的rank是在idx中的下标,相当于对应第几小,从而得到ts张量中每个元素在由小到大排序后所在的位置。
若想获取由大到小排序后的位置,仅需将第一个sort改为ts.sort(dim=1, descending=True),程序如下
import torch
ts = torch.randn((1, 5), dtype=torch.float)
print(ts)
srt1, idx = ts.sort(dim=1, descending=True)
print(idx)
srt2, rank = idx.sort(dim=1)
print(rank)
结果如下
tensor([[ 0.7378, 0.7733, -0.8082, 0.8138, 0.0559]])
tensor([[3, 1, 0, 4, 2]])
tensor([[2, 1, 4, 0, 3]])