两次torch.sort求元素在排序后的位置
两次torch.sort求元素在排序后的位置目的:利用两次sort求原张量中每个元素第几小(大),测试用例如下import torchts = torch.randn((1, 5), dtype=torch.float)print(ts)srt1, idx = ts.sort(dim=1)print(idx)srt2, rank = idx.sort(dim=1) print(rank)。。。。得到的结果如下tensor([[-1.2281, 0.6057, -1.1720,
原创
2020-07-22 21:37:34 ·
1024 阅读 ·
0 评论