Hello,大家好,今天主要为大家讲一下pytorch的topk函数。
torch.topk
(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
该函数的作用就是沿着指定的dim(维度)返回input(输入张量)的k个最大的元素。
如果dim没有指定,默认值是input的最后一个维度。如果largest为False,那么返回的是input的k个最小的元素。函数的返回结果是(values, indices)的具名元祖,其中indices是返回的k个元素在原张量中的索引值。如果bool型选项sorted为True,那么返回的k个元素的顺序是排好的。
例子:
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1., 2., 3., 4., 5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
祝大家科研顺利:)