-
函数介绍
a.topk()求a中的最大值或最小值,返回两个值,一个是a中的值(最大或最小),一个是这个值的索引。 -
代码示例
>>> import torch
>>> a=torch.randn((3,5))
>>> a
tensor([[-0.4790, -0.6308, 0.2370, 0.0380, -0.0579],
[-0.6712, -3.5483, -0.2370, -0.8658, 0.4145],
[-1.4126, -0.8786, -0.4216, -0.0878, -1.4015]])
>>> _,pre=a.topk(1,dim=1,largest=True)
>>> pre
tensor([[2],
[4],
[3]])
>>> _
tensor([[ 0.2370],
[ 0.4145],
[-0.0878]])
>>> _,pre=a.topk(1,dim=1,largest=False)
>>> pre
tensor([[1],
[1],
[0]])
>>> _
tensor([[-0.6308],
[-3.5483],
[-1.4126]])
dim=1,为按行求最大最小值,largest为Ture,求最大值,largest=False,求最小值。