目的:求取 tensor 中某个 dim 的排序前 k 个值 (val) 以及其索引 (index)。
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
#input:tensor数据
#k:排序后的前k个数据
#dim:沿着某个维度
#largest:True是指从大到小,否则,从小到大
#sorted:True返回的结果
示例:
import torch
if __name__=="__main__":
pred = torch.rand((4, 5))
print(pred)
print("------------k=1------------------")
vals, indices = pred.topk(k=1, dim=1, largest=True, sorted=True)
print(indices)
print("------------k=2------------------")
vals, indices = pred.topk(k=2, dim=1, largest=True, sorted=True)
print(indices)
#output:
tensor([[0.2219, 0.9817, 0.7909, 0.7659, 0.1657],
[0.6779, 0.9653, 0.1959, 0.3108, 0.1755],
[0.9107, 0.5243, 0.2525, 0.1543, 0.4314],
[0.5417, 0.0409, 0.5777, 0.3693, 0.2606]])
------------k=1------------------
tensor([[1],
[1],
[0],
[2]])
------------k=2------------------
tensor([[1, 2],
[1, 0],
[0, 1],
[2, 0]])