介绍
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)
- 功能:返回给定输入张量在给定维度上的前k个最大元素
- 如果没有给出dim,则选择输入的最后一个维度。
- 如果’largest =False’ 则返回最小的k个元素
- 函数返回:返回一个由(值、索引)组成的命名元组,其中索引是原始输入张量中元素的索引
- 如果’sorted=True’则返回从大到小排序之后的元素,以及对应的索引
Parameters
input (Tensor) – 输入的张量
k (int) – the k in “top-k”
dim (int, optional) – 排序的维度
largest (bool, optional) – 控制是否返回最大或最小的元素
sorted (bool, optional) – 控制是否按排序顺序返回元素
例
>>> import torch
>>> x = torch.arange(1., 6.)
>>> x
tensor([1., 2., 3., 4., 5.])
>>> topk_list = torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]),
indices=tensor([4, 3, 2]))
>>> topk_list[0]
tensor([5., 4., 3.])
>>> topk_list[1]
tensor([4, 3, 2])
从上边的列子我们可以看到,可以使用top_list[0]来获取前3个最大的经过由大到小排序的元素
使用top_[1]来获取这三个元素对应的索引列表。还是很方便的