官方文档:https://pytorch.org/docs/stable/generated/torch.topk.html?highlight=topk#torch.topk
由于numpy本身是没有提供topk方法的,自己写一个有时候又很蛋疼(懒得写),在这种情况下便可以考虑pytorch提供的topk:
values, indices = torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
- input:输入tensor
- k:字面意思
- dim:按哪一维进行排序
- sorted:返回的元素是否要排序
- values:最大的k个值
- indices:最大值所对应的下标
一般来说,很多情况下单纯就是馋这个indices,这里举一个实际中可能遇到的例子。比如说,我们现在有一个图片的list:
image_name = ['A.jpg', 'B.jpg', 'C.jpg', 'D.jpg', 'E.jpg']
每张图片对应有一个分数:
score = [0.57, 0.59, 0.38, 0.77, 0.25]
要将分数最高的三张图片给选出来。一种方法是利用面向对象的思想将其建模为一个class然后定义规则进行排序,这里我们利用topk的方式如下:
import numpy as np
import torch
name = ['A.jpg', 'B.jpg', 'C.jpg', 'D.jpg', 'E.jpg']
score = [0.57, 0.59, 0.38, 0.77, 0.25]
name = np.array(name)
score = torch.Tensor(score)
val, idx = torch.topk(score, 3)
idx = idx.numpy()
print(name[idx])
输出结果为:
['D.jpg' 'B.jpg' 'A.jpg']