1 作用
对一个 tensor 中的元素取它的前 K 个元素 (从大到小排列)
2 参数介绍
import torch
torch.topk(input, k, dim = None, largest = True, sorted = True, *, out = None)
input(Tensor) : 输入的张量
k(int) : 前 k 个大小中的 k
dim(int, optional) : 需要进行排序的维度, dim = 0 表示按照列来排序, dim = 1 表示按照行来排序, 默认情况下, dim = 1
largest(bool, optional) : 控制是否返回最大值或最小值
sorted(bool, optional) : 控制是否对元素进行排序后再返回
out(tuple,可选):(Tensor,LongTensor)的输出元组,可以可选地指定用作输出缓冲区
3 注意事项
返回给定维度上给定输入张量的k个最大元素。
如果未给出dim,则选择输入的最后一个维度。
如果maximum为False,则返回k个最小元素。
返回一个(值,索引)的命名元组,其中包含给定维度dim中输入张量每行的最大k个元素的值和索引。
如果为True,则布尔选项将确保返回的k个元素本身已排序
4 示例
import torch
a = torch.randn(3, 3)
print("a : ", a)
b = torch.topk(a, 2, dim = 1, largest = False)
print("b :", b)
>>> a : tensor([[-0.6474, -0.0939, 1.3639],
[-0.0297, 0.6471, -0.2255],
[-1.2431, -0.3386, 1.9692]])
>>> b : torch.return_types.topk(
>>> values=tensor([[-0.6474, -0.0939],
[-0.2255, -0.0297],
[-1.2431, -0.3386]]),
>>> indices=tensor([[0, 1],
[2, 0],
[0, 1]]))