torch.topk(input, k, dim=n, largest=True, sorted=True, out=None)
其中:
input
-> 输入tensor
k
-> 前k
个sorted
-> 是否排序largest
-> False
表示返回第k个最小值
以上四个都好理解,本文主要讲的是dim的取值,并举例说明。
首先创建数组:
import torch
pred=torch.tensor([[1,5,3,9],
[4,9,2,8],
[6,4,1,7],
[2,8,5,0]])
当dim的取值为0,输出前三个数组的最大值。
values=pred.topk(3, dim=0, largest=True, sorted=True)
print(values)
运行结果如下:
torch.return_types.topk(
values=tensor([[6, 9, 5, 9],
[4, 8, 3, 8],
[2, 5, 2, 7]]),
indices=tensor([[2, 1, 3, 0],
[1, 3, 0, 1],
[3, 0, 1, 2]]))
可以看到,当dim=0时,原数组按列取最大的三个数。
当dim的取值为1,输出前三个数组的最大值。
values=pred.topk(3, dim=1, largest=True, sorted=True)
print(values)
运行结果如下:
torch.return_types.topk(
values=tensor([[9, 5, 3],
[9, 8, 4],
[7, 6, 4],
[8, 5, 2]]),
indices=tensor([[3, 1, 2],
[1, 3, 0],
[3, 0, 1],
[1, 2, 0]]))
可以看到,当dim=1时,原数组按行取最大的三个数。