topk(num,dim=1)
>>> output=torch.randn(3,4)
>>> output
tensor([[-1.9291, 1.4127, -2.2464, 0.8932],
[-0.4483, -0.3458, 0.8384, 1.9580],
[-0.5633, -2.2806, 0.6278, 1.3552]])
在行上取一个最大值
>>> topkv,topki=output.topk(1,1)
>>> topkv
tensor([[1.4127],
[1.9580],
[1.3552]])
>>> topki
tensor([[1],
[3],
[3]])
在行上取前两个最大值
>>> topkv,topki=output.topk(2,1)
>>> topkv
tensor([[1.4127, 0.8932],
[1.9580, 0.8384],
[1.3552, 0.6278]])
>>> topki
tensor([[1, 3],
[3, 2],
[3, 2]])
topk(num,dim=0)
>>> output=torch.randn(3,4)
>>> output
tensor([[-1.9291, 1.4127, -2.2464, 0.8932],
[-0.4483, -0.3458, 0.8384, 1.9580],
[-0.5633, -2.2806, 0.6278, 1.3552]])
在列上取一个最大值
>>> topkv,topki=output.topk(1,0)
>>> topkv
tensor([[-0.4483, 1.4127, 0.8384, 1.9580]])
>>> topki
tensor([[1, 0, 1, 1]])
在列上取两个最大值
>>> topkv,topki=output.topk(2,0)
>>> topkv
tensor([[-0.4483, 1.4127, 0.8384, 1.9580],
[-0.5633, -0.3458, 0.6278, 1.3552]])
>>> topki
tensor([[1, 0, 1, 1],
[2, 1, 2, 2]])