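PyTorch torch.topk usage
torch.topk(input, k, dim, largest, sorted) returns the k largest elements of a tensor along dimension dim (default: the last dimension). With largest=False it returns the k smallest instead. The result is a (values, indices) pair. The example below takes the top 2 entries of each row of a 4x5 tensor.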
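import torch

# Random 4x5 tensor; no seed is set, so the numbers change on every run
pred = torch.randn((4, 5))
print("pred is", pred, "\n")
# Take the k=2 largest entries along dim=1 (within each row), sorted in descending order
values, indices = pred.topk(2, dim=1, largest=True, sorted=True)
print("values are", values, "\n")
print("indices are", indices, "\n")

Output:
pred is tensor([[-0.7044, -0.3443,  1.2655, -0.8944, -0.0917],
        [-0.9605, -0.1830, -0.9964, -0.0696, -2.1732],
        [-0.2562,  0.0429,  1.1153, -0.0081, -1.8422],
        [ 1.4038,  0.3336,  0.9309,  0.2830, -0.4532]])

values are tensor([[ 1.2655, -0.0917],
        [-0.0696, -0.1830],
        [ 1.1153,  0.0429],
        [ 1.4038,  0.9309]])

indices are tensor([[2, 4],
        [3, 1],
        [2, 1],
        [0, 2]])

Each row of indices gives the column positions in pred of the matching entries in values. For example, in row 0, 1.2655 sits at column 2 and -0.0917 at column 4.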
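This second sketch uses a fixed, made-up input so the results are reproducible. It shows two more behaviors. With largest=False, topk returns the smallest entries in ascending order. With dim=0, it compares values down each column instead of across each row. The sketch also checks that passing the returned indices to torch.gather recovers the returned values. The tensor below is purely illustrative and does not come from the original example.

```python
import torch

# Illustrative fixed input (no ties), so results are the same on every run
pred = torch.tensor([[0.5, -1.0, 2.0, 0.1],
                     [3.0, 0.0, -2.0, 1.5]])

# largest=False: the 2 smallest entries per row, sorted ascending
small_vals, small_idx = pred.topk(2, dim=1, largest=False)
print(small_vals)
print(small_idx)

# dim=0: the single largest entry in each column
col_vals, col_idx = pred.topk(1, dim=0)
print(col_vals)
print(col_idx)

# The indices index into pred along dim=1, so gather returns the top values
top_vals, top_idx = pred.topk(2, dim=1)
assert torch.equal(torch.gather(pred, 1, top_idx), top_vals)
```

The input deliberately avoids ties, so this sketch does not depend on which index topk reports when several elements are equal.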