torch.topk()
是 PyTorch 中的一个非常有用的函数,它用于返回输入张量中每个元素沿指定维度的最大 k
个元素及其索引。这个函数在很多场景中都非常有用,比如找到每个类别中得分最高的几个元素,或者在处理自然语言任务时选择概率最高的几个单词等。
函数签名
python复制代码
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor) |
参数
- input (Tensor) – 输入张量。
- k (int) – 每个元素要返回的最大(或最小)值的数量。
- dim (int, optional) – 沿着此维度查找最大(或最小)值。如果未指定,则默认将输入张量视为1D向量并返回整个张量的前k个最大(或最小)值。
- largest (bool, optional) – 如果为True,则返回最大的k个值;如果为False,则返回最小的k个值。默认为True。
- sorted (bool, optional) – 如果为True,则返回的值将按降序排列;如果为False,则返回的值的顺序将不确定。注意,即使
sorted=False
,如果largest=True
,则较大的值仍然会倾向于排在前面,但不一定是完全排序的。默认为True。 - out (tuple, optional) – 一个可选的输出元组,其中第一个元素是一个Tensor,用于存储结果值;第二个元素是一个LongTensor,用于存储这些值的索引。这通常用于减少内存分配。
返回值
- values (Tensor) – 沿着指定维度返回的最大(或最小)的k个元素。
- indices (LongTensor) – 返回的元素的索引。
示例
python复制代码
import torch | |
# 创建一个随机张量 | |
x = torch.randn(3, 4) | |
print("Original tensor:") | |
print(x) | |
# 找到每行的前2个最大值及其索引 | |
top_vals, top_indices = torch.topk(x, 2, dim=1) | |
print("Top 2 values:") | |
print(top_vals) | |
print("Indices of top 2 values:") | |
print(top_indices) | |
# 找到整个张量中的前2个最大值及其索引(不指定dim) | |
top_vals_global, top_indices_global = torch.topk(x, 2) | |
print("Global top 2 values:") | |
print(top_vals_global) | |
print("Indices of global top 2 values:") | |
print(top_indices_global) |
在这个示例中,我们首先创建了一个3x4的随机张量,并分别找到了每行中的前2个最大值及其索引,以及整个张量中的前2个全局最大值及其索引。注意,当不指定dim
时,torch.topk()
会将输入张量视为一个1D向量。