tensor.topk()
是 PyTorch 中的一个函数,用于获取张量(Tensor)中的最大的 k
个元素及其对应的索引。这在处理分类问题和排序问题时很常用。
下面是 tensor.topk()
的基本使用方法的示例,以及对其接口参数的详细说明:
接口参数
topk()
函数的典型接口如下:
tensor_x.topk(k, dim=-1, largest=True, sorted=True)
tensor_x
:输入的张量。k
:要返回的元素数量。dim
:要进行操作的维度。默认为最后一个维度。largest
:如果设置为True
,返回最大的k
个元素;如果为False
,则返回最小的k
个元素。默认为True
。sorted
:如果设置为True
,返回的k
个元素将按照顺序排序;如果为False
,则不保证排序。默认为True
。
用法示例
# 创建一个形状为 (3, 4, 5) 的随机矩阵
x_random = torch.randn(3, 4, 5)
# 在最后一个维度上找最大的2个元素
values, indices = x_random.topk(2, dim=-1)