两次torch.sort求元素在排序后的位置

两次torch.sort求元素在排序后的位置

  • 目的:
    利用两次sort求原张量中每个元素第几小(大),测试用例如下
import torch

ts = torch.randn((1, 5), dtype=torch.float)
print(ts)
srt1, idx = ts.sort(dim=1)
print(idx)
srt2, rank = idx.sort(dim=1) 
print(rank)

得到的结果如下

tensor([[-1.2281,  0.6057, -1.1720,  1.1262, -0.1582]])
tensor([[0, 2, 4, 1, 3]])   #idx
tensor([[0, 3, 1, 4, 2]])   #rank

在第一次排序后,idx给出的是从小到大排列后,元素在ts中的位置,再次排列后,可以根据元素的大小,来判断张量ts每个元素是第几小。
由rank可知,原张量ts第0个元素对应的是0,即第0小(最小),第1个元素对应的是3,则第3小,以此类推。

  • 原理
    第一次sort后,我们知道[0, 2, 4, 1, 3]指代ts中第几个元素,且下标[0, 1, 2, 3, 4]对应第几小,[0, 2, 4, 1, 3]再经过一次sort后,会得到[0, 1, 2, 3, 4],相当于恢复原始张量ts的排布,而此时的rank是在idx中的下标,相当于对应第几小,从而得到ts张量中每个元素在由小到大排序后所在的位置。

若想获取由大到小排序后的位置,仅需将第一个sort改为ts.sort(dim=1, descending=True),程序如下

import torch

ts = torch.randn((1, 5), dtype=torch.float)
print(ts)
srt1, idx = ts.sort(dim=1, descending=True)
print(idx)
srt2, rank = idx.sort(dim=1) 
print(rank)

结果如下

tensor([[ 0.7378,  0.7733, -0.8082,  0.8138,  0.0559]])
tensor([[3, 1, 0, 4, 2]])
tensor([[2, 1, 4, 0, 3]])
  • 6
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值