torch.topk
是 PyTorch 中的一个函数,用于在指定维度上获取张量中最大的 k 个值及其对应的索引。
该函数的用法如下:
values, indices = torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)
input:要操作的输入张量。
k:要获取的最大值的个数。
dim:指定在哪个维度上获取最大值,默认为输入张量的最后一个维度。
largest:如果为 True,则获取最大的 k 个值;如果为 False,则获取最小的 k 个值。默认为 True。
sorted:如果为 True,则返回的最大值和对应的索引按照值排序;如果为 False,则返回的最大值和对应的索引的顺序不确定。默认为 True。
out:可选参数,用于指定输出张量。
这个函数对于需要在张量中找到最大值或最小值的任务非常有用,比如在排序、选择或计算统计信息时。