Pytorch torch.topk()的简单用法

官方文档:https://pytorch.org/docs/stable/generated/torch.topk.html?highlight=topk#torch.topk

由于numpy本身是没有提供topk方法的,自己写一个有时候又很蛋疼(懒得写),在这种情况下便可以考虑pytorch提供的topk:

values, indices = torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
  • input:输入tensor
  • k:字面意思
  • dim:按哪一维进行排序
  • sorted:返回的元素是否要排序
  • values:最大的k个值
  • indices:最大值所对应的下标

一般来说,很多情况下单纯就是馋这个indices,这里举一个实际中可能遇到的例子。比如说,我们现在有一个图片的list:

image_name = ['A.jpg', 'B.jpg', 'C.jpg', 'D.jpg', 'E.jpg']

每张图片对应有一个分数:

score = [0.57, 0.59, 0.38, 0.77, 0.25]

要将分数最高的三张图片给选出来。一种方法是利用面向对象的思想将其建模为一个class然后定义规则进行排序,这里我们利用topk的方式如下:

import numpy as np
import torch

name = ['A.jpg', 'B.jpg', 'C.jpg', 'D.jpg', 'E.jpg']
score = [0.57, 0.59, 0.38, 0.77, 0.25]

name = np.array(name)
score = torch.Tensor(score)
val, idx = torch.topk(score, 3)
idx = idx.numpy()
print(name[idx])

输出结果为:

['D.jpg' 'B.jpg' 'A.jpg']
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值