2025/10/22—PyTorch Tensor.topk

AgenticAI·十月创作之星挑战赛 10w+人浏览 1k人参与

PyTorch Tensor.topk 笔记(用于 Top-k 准确率)

1) 函数原型

torch.Tensor.topk(k, dim=None, largest=True, sorted=True, *, out=None)

作用:在给定维度上返回张量中前 k 个元素的 索引


2) 关键参数(结合分类输出 output ∈ R^{N×C}

  • k / maxk:取前 k 个元素的数量。

  • dim=1:沿类别维度(每一行一个样本)取 Top-k。

  • largest=True:取最大值(分类里通常如此;False 取最小值)。

  • sorted=True:返回的 k 个结果按降序排列(便于切片 Top-1、Top-5)。


3) 返回值

values, indices = output.topk(k, dim=1, largest=True, sorted=True)
  • values:形状 [N, k],每个样本的前 k 个分数。

  • indices(常记为 pred):形状 [N, k],对应前 k 个分数的类别索引

实际计算 Top-k 准确率时,通常仅需 indices


4) 例子

output = torch.tensor([
    [0.8, 0.1, 0.5, 0.3, 0.2],  # 样本1
    [0.4, 0.9, 0.2, 0.7, 0.3]   # 样本2
])
values, pred = output.topk(3, dim=1, largest=True, sorted=True)

# values:
# tensor([[0.8, 0.5, 0.3],
#         [0.9, 0.7, 0.4]])

# pred (类别索引):
# tensor([[0, 2, 3],
#         [1, 3, 0]])

5) 在 Top-k 准确率中的用法骨架

# output: [N, C],target: [N] (int64 类别索引)
maxk = max(topk)                  # 比如 topk = (1, 5)
_, pred = output.topk(maxk, 1, True, True)  # pred: [N, maxk]
pred = pred.t()                   # [maxk, N]
correct = pred.eq(target.view(1, -1).expand_as(pred))  # [maxk, N] bool

accs = []
for k in topk:
    # 用 reshape 比 view 更稳妥(避免非连续内存问题)
    correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)[0]
    accs.append(correct_k * 100.0 / output.size(0))

6) 常见注意点

  • maxk ≤ C(类别数),否则 topk 会报错。

  • target.dtype == torch.longint64),否则与 pred.eq(...) 不匹配。

  • sorted=True 便于“先算最大 k,再切前 k”;但不是必须。

  • 若遇到 view 报非连续错误,改用 .reshape(-1)

  • pred 转置为 [maxk, N] 是为了按行切前 k 并与 target 对齐比较。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值