得到矩阵每列的topk的mask

现有1矩阵:

>>> a                                                                            
tensor([[-0.5882,  0.1056,  2.3118],                                                                                                                          
 [-0.2930, -1.2162, -0.6995],                                                                                                                         
 [ 0.5129, -0.3037, -0.8628],                                       
 [-0.4171,  0.0443,  1.1761]])

现在我们要得到每列的最大的2个数的mask

_, indices = a.sort(dim=0)
>>> indices 
tensor([[0, 1, 2],
[3, 2, 1],
[1, 3, 3],   
[2, 0, 0]])

先sort,dim=0表示按列sort,我们可以看到第一行,第一个数为0,表示第一列第0个数最大,第3个数为2,表示第三列第2个数最大,也就是说,我们现在拿到的是从上往下是每一列最大的数的序号,我们希望的是,最大的数所在的位置的值代表它的排序,也就是说,希望(2,3)位置的值为0,(0,3)的位置为3,怎么办呢?再sort一次

_,indices2 = indices.sort(dim=0)
tensor([[0, 3, 3],
[2, 0, 1],      
[3, 1, 0],                                                
[1, 2, 2]])
indices2<k#即为topk的mask                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值