torch.topk()使用方法及其示例

1.作用

取一个tensor的topk元素

2.使用方法

dim=0表示按照列求topn,dim=1表示按照行求topn,None情况下,dim=1.

任务一:

取top1:

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])

任务二:
按行取出topk,将小于topk的置为inf

pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
print(pred)
top_k = 2
filter_value=-float('Inf')
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value  # 对于topk之外的其他元素的logits值设为负无穷
print(pred)

输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145,  0.4053],
        [ 0.7265,  1.4164,  1.3443,  1.2035,  1.8823],
        [-0.4451,  0.1673,  1.2590, -2.0757,  1.7255],
        [ 0.2021,  0.3041,  0.1383,  0.3849, -1.6311]])
tensor([[4],
        [4],
        [4],
        [3]])
tensor([[0.4053],
        [1.8823],
        [1.7255],
        [0.3849]])
tensor([[ True, False,  True,  True, False],
        [ True, False,  True,  True, False],
        [ True,  True, False,  True, False],
        [ True, False,  True, False,  True]])
tensor([[   -inf, -0.3873,    -inf,    -inf,  0.4053],
        [   -inf,  1.4164,    -inf,    -inf,  1.8823],
        [   -inf,    -inf,  1.2590,    -inf,  1.7255],
        [   -inf,  0.3041,    -inf,  0.3849,    -inf]])

参考:

https://blog.csdn.net/qq_34914551/article/details/103738160
https://blog.csdn.net/u014264373/article/details/86525621
https://blog.csdn.net/weixin_45062709/article/details/102711885

  • 14
    点赞
  • 31
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值