Topk
含义
topk()方法用于返回输入数据中特定维度上的前k个最大的元素。
操作
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
参数说明:
- input -> 输入tensor
- k -> 前k个
- dim -> 默认为输入tensor的最后一个维度
- sorted -> 是否排序
- largest -> False表示返回第k个最小值
实践
>>> x = torch.arange(1., 6.)
>>> x
tensor([1.,2.,3.,4.,5.])
>>> torch.topk(x, 3)
tensor([5.,4.,3.]), indices=tensor([4,3,2])
总结
使用torch.topk可以指定特定维度、特定前N个最大/最小值,并输出其值和索引。