根据Pytorch中的手册可以看到,topk()方法用于返回输入数据中特定维度上的前k个最大的元素。
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
参数说明:
input -> 输入tensor
k -> 前k个
dim -> 默认为输入tensor的最后一个维度
sorted -> 是否排序
largest -> False表示返回第k个最小值
>>> x = torch.arange(1., 6.)
>>> x
tensor([1.,2.,3.,4.,5.])
>>> torch.topk(x, 3)
tensor([5.,4.,3.]), indices=tensor([4,3,2])