pytorch中topk()用法的测试与个人理解
参数介绍:
直接官网的介绍topk()
input:就是输入的tensor,也就是要取topk的张量
k:就是取前k个最大的值。
dim:就是在哪一维来取这k个值。
lagest:默认是true表示取前k大的值,false则表示取前k小的值
sorted:是否按照顺序输出,默认是true。
out : 可选输出张量 (Tensor, LongTensor)
直接上代码:
首先研究一下dim和k这两个最重要的参数:
import torch
seed = 0
torch.manual_seed(see
原创
2021-09-25 10:21:07 ·
1187 阅读 ·
0 评论