PyTorch学习(1):torch.meshgrid的使用-CSDN博客
PyTorch学习(2):torch.device-CSDN博客
目录
1. 简述
PyTorch的topk函数用于返回Tensor中的前k个元素及其对应的索引。这个函数非常有用,尤其是在处理分类问题时,可以用来找出每个样本最可能属于的k个类别。
在实际的部署过程中,topk非常重要,可以实现快速的筛选和过滤。
2. 原型
import torch
torch.topk(input, k, dim = None, largest = True, sorted = True, *, out = None)
对一个 tensor 在指定的维度按照指定的排序取它的前 K 个元素 (默认从大到小排序)
参数:
- input(Tensor) : 输入的张量;
- k(int) : 前 k 个大小中的 k;
- dim(int, optional) : 需要进行排序的维度, dim = 0 表示按照列来排序, dim = 1 表示按照行来排序, 默认情况下, dim = -1,即在最后一个维度进行操作;
- largest(bool, optional) : 控制是否返回最大值或最小值;
- sorted(bool, optional) : 控制是否对元素进行排序后再返回;
- out(tuple,可选):(Tensor,LongTensor)的输出元组,可以可选地指定用作输出缓冲区;
3. 例程
在下面代码中,pred是一个包含4个样本和5个类别的Tensor。使用topk函数,我们可以分别获取每个样本最可能属于的类别(k=1)和前两个最可能属于的类别(k=2)。values变量包含了选中的元素,而indices变量包含了这些元素在原始Tensor中的索引。
import torch
# 创建一个随机Tensor
pred = torch.randn((4, 5))
# 获取每个样本最可能属于的类别(k=1)
values, indices = pred.topk(1, dim=1, largest=True, sorted=True)
print("Values:", values)
print("Indices:", indices)
# 获取每个样本最可能属于的前两个类别(k=2)
values, indices = pred.topk(2, dim=1, largest=True, sorted=True)
print("Values:", values)
print("Indices:", indices)