torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
topk的原型如上:
其中k是保留的k个值,largest=True意味着选取最大的,sorted=True是指将返回结果排序
topk返回的是一个tuple,第一个元素指返回的具体值,第二个元素指返回值的index
直接贴代码
import torch
x = torch.rand(2, 3, 3)
y = x.topk(2, largest = True, sorted = True)
结果如下
>>> x
tensor([[[0.0858, 0.4492, 0.4394],
[0.2662, 0.5704, 0.5212],
[0.1720, 0.8962, 0.4634]],
[[0.5535, 0.5748, 0.0194],
[0.0723, 0.7901, 0.4427],
[0.7804, 0.8924, 0.3323]]])
>>> y
(tensor([[[0.4492, 0.4394],
[0.5704, 0.5212],
[0.8962, 0.4634]],
[[0.5748, 0.5535],
[0.7901, 0.4427],
[0.8924, 0.7804]]]), tensor([[[1, 2],
[1, 2],
[1, 2]],
[[1, 0],
[1, 2],