torch.topk() 函数

torch.topk() 是 PyTorch 中的一个非常有用的函数,它用于返回输入张量中每个元素沿指定维度的最大 k 个元素及其索引。这个函数在很多场景中都非常有用,比如找到每个类别中得分最高的几个元素,或者在处理自然语言任务时选择概率最高的几个单词等。

函数签名

 

python复制代码

torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)

参数

  • input (Tensor) – 输入张量。
  • k (int) – 每个元素要返回的最大(或最小)值的数量。
  • dim (int, optional) – 沿着此维度查找最大(或最小)值。如果未指定,则默认将输入张量视为1D向量并返回整个张量的前k个最大(或最小)值。
  • largest (bool, optional) – 如果为True,则返回最大的k个值;如果为False,则返回最小的k个值。默认为True。
  • sorted (bool, optional) – 如果为True,则返回的值将按降序排列;如果为False,则返回的值的顺序将不确定。注意,即使sorted=False,如果largest=True,则较大的值仍然会倾向于排在前面,但不一定是完全排序的。默认为True。
  • out (tuple, optional) – 一个可选的输出元组,其中第一个元素是一个Tensor,用于存储结果值;第二个元素是一个LongTensor,用于存储这些值的索引。这通常用于减少内存分配。

返回值

  • values (Tensor) – 沿着指定维度返回的最大(或最小)的k个元素。
  • indices (LongTensor) – 返回的元素的索引。

示例

 

python复制代码

import torch
# 创建一个随机张量
x = torch.randn(3, 4)
print("Original tensor:")
print(x)
# 找到每行的前2个最大值及其索引
top_vals, top_indices = torch.topk(x, 2, dim=1)
print("Top 2 values:")
print(top_vals)
print("Indices of top 2 values:")
print(top_indices)
# 找到整个张量中的前2个最大值及其索引(不指定dim)
top_vals_global, top_indices_global = torch.topk(x, 2)
print("Global top 2 values:")
print(top_vals_global)
print("Indices of global top 2 values:")
print(top_indices_global)

在这个示例中,我们首先创建了一个3x4的随机张量,并分别找到了每行中的前2个最大值及其索引,以及整个张量中的前2个全局最大值及其索引。注意,当不指定dim时,torch.topk()会将输入张量视为一个1D向量。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值