PyTorch Tensor.topk 笔记(用于 Top-k 准确率)
1) 函数原型
torch.Tensor.topk(k, dim=None, largest=True, sorted=True, *, out=None)
作用:在给定维度上返回张量中前 k 个元素的 值 和 索引。
2) 关键参数(结合分类输出 output ∈ R^{N×C})
-
k / maxk:取前 k 个元素的数量。
-
dim=1:沿类别维度(每一行一个样本)取 Top-k。
-
largest=True:取最大值(分类里通常如此;
False取最小值)。 -
sorted=True:返回的 k 个结果按降序排列(便于切片 Top-1、Top-5)。
3) 返回值
values, indices = output.topk(k, dim=1, largest=True, sorted=True)
-
values:形状
[N, k],每个样本的前 k 个分数。 -
indices(常记为
pred):形状[N, k],对应前 k 个分数的类别索引。
实际计算 Top-k 准确率时,通常仅需
indices。
4) 例子
output = torch.tensor([
[0.8, 0.1, 0.5, 0.3, 0.2], # 样本1
[0.4, 0.9, 0.2, 0.7, 0.3] # 样本2
])
values, pred = output.topk(3, dim=1, largest=True, sorted=True)
# values:
# tensor([[0.8, 0.5, 0.3],
# [0.9, 0.7, 0.4]])
# pred (类别索引):
# tensor([[0, 2, 3],
# [1, 3, 0]])
5) 在 Top-k 准确率中的用法骨架
# output: [N, C],target: [N] (int64 类别索引)
maxk = max(topk) # 比如 topk = (1, 5)
_, pred = output.topk(maxk, 1, True, True) # pred: [N, maxk]
pred = pred.t() # [maxk, N]
correct = pred.eq(target.view(1, -1).expand_as(pred)) # [maxk, N] bool
accs = []
for k in topk:
# 用 reshape 比 view 更稳妥(避免非连续内存问题)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)[0]
accs.append(correct_k * 100.0 / output.size(0))
6) 常见注意点
-
maxk ≤ C(类别数),否则topk会报错。 -
target.dtype == torch.long(int64),否则与pred.eq(...)不匹配。 -
sorted=True便于“先算最大 k,再切前 k”;但不是必须。 -
若遇到
view报非连续错误,改用.reshape(-1)。 -
pred转置为[maxk, N]是为了按行切前 k 并与target对齐比较。
1万+

被折叠的 条评论
为什么被折叠?



