values, indices = torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
- input:输入tensor
- k:字面意思
- dim:按哪一维进行排序
- sorted:返回的元素是否要排序
- values:最大的k个值
- indices:最大值所对应的下标
一般来说,很多情况下单纯就是馋这个indices,这里举一个实际中可能遇到的例子。比如说,我们现在有一个图片的list:
image_name = ['A.jpg', 'B.jpg', 'C.jpg', 'D.jpg', 'E.jpg']
每张图片对应有一个分数:
score = [0.57, 0.59, 0.38, 0.77, 0.25]
要将分数最高的三张图片给选出来。一种方法是利用面向对象的思想将其建模为一个class然后定义规则进行排序,这里我们利用topk的方式如下:
import numpy as np
import torch
name = ['A.jpg', 'B.jpg', 'C.jpg', 'D.jpg', 'E.jpg']
score = [0.57, 0.59, 0.38, 0.77, 0.25]
name = np.array(name)
score = torch.Tensor(score)
val, idx = torch.topk(score, 3)
idx = idx.numpy()
print(name[idx])
输出结果为:
['D.jpg' 'B.jpg' 'A.jpg']
————————————————
版权声明:本文为CSDN博主「xiongxyowo」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_40714949/article/details/123017649
进阶请看链接2与3
参考资料
Pytorch torch.topk()的简单用法_xiongxyowo的博客-CSDN博客_torch.topk()