现有1矩阵:
>>> a
tensor([[-0.5882, 0.1056, 2.3118],
[-0.2930, -1.2162, -0.6995],
[ 0.5129, -0.3037, -0.8628],
[-0.4171, 0.0443, 1.1761]])
现在我们要得到每列的最大的2个数的mask
_, indices = a.sort(dim=0)
>>> indices
tensor([[0, 1, 2],
[3, 2, 1],
[1, 3, 3],
[2, 0, 0]])
先sort,dim=0表示按列sort,我们可以看到第一行,第一个数为0,表示第一列第0个数最大,第3个数为2,表示第三列第2个数最大,也就是说,我们现在拿到的是从上往下是每一列最大的数的序号,我们希望的是,最大的数所在的位置的值代表它的排序,也就是说,希望(2,3)位置的值为0,(0,3)的位置为3,怎么办呢?再sort一次
_,indices2 = indices.sort(dim=0)
tensor([[0, 3, 3],
[2, 0, 1],
[3, 1, 0],
[1, 2, 2]])
indices2<k#即为topk的mask