pytorch用法1: 数组排序后复原
主要是利用torch.sort函数里返回的第二个参数index,这个index表示的是排序后的数字在原来数组中的位置。
比如:
l = torch.randint(10,(10,))
a, idx1 = torch.sort(l)
结果为:
l: tensor([3., 3., 8., 7., 9., 9., 7., 4., 5., 1.])
a: tensor([1., 3., 3., 4., 5., 7., 7., 8., 9., 9.])
idx1: tensor([9, 0, 1, 7, 8, 3, 6, 2, 4, 5])
这里的index结果比如是从0到N的,这里的idx1中的值就对应a中的值在b中的位置,比如idx1中的第一个9表示a中的1在数组l中出现在第9个。
如果我们将这个index再进行排序就会发现第二次得到的index保留了一些有用的信息:
b, idx2 = torch.sort(idx1)
b: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
idx2: tensor([1, 2, 7, 5, 8, 9, 6, 3, 4, 0])
idx2的结果表示的是b中的值在idx1中出现的原位置,而idx1中的数的位置表示的正是数组a中的数字在原始数组l中的数的位置。于是如果使用a按照idx2的结果来选择数字,就得到了原始数组l:
a.index_select(0,idx2) #0表示选择的维度
结果正是数组l。
组合起来如下所示:
import torch
l = torch.randint(10,(10,))
a, idx1 = to