pytorch topk 保持维度和位置 置零

pytorch的topk能够返回最大的k个值,现在假设有一个[2,3,4]的权重矩阵,如果我们需要在第三个维度找出最大的两个值,(并保持权重矩阵的维度不变,且最大值的位置也不变),topk就不是很好用了,以下代码能解决这个问题:

import torch
import numpy as np
if __name__ == "__main__":
    x=torch.tensor(np.arange(1,25)).reshape(2,3,4)
    print(x)
    # k=2表示选择两个最大值
    a,_=x.topk(k=2,dim=2)
    # 要加上values,否则会得到一个包含values和indexs的对象
    a_min=torch.min(a,dim=-1).values
    # repeat里的4和x的最后一维相同
    a_min=a_min.unsqueeze(-1).repeat(1,1,4)
    ge=torch.ge(x,a_min)
    # 设置zero变量,方便后面的where操作
    zero=torch.zeros_like(x)
    result=torch.where(ge,x,zero)
    print(result)

输出是:

# 原矩阵
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12]],

        [[13, 14, 15, 16],
         [17, 18, 19, 20],
         [21, 22, 23, 24]]], dtype=torch.int32)
# 每个维度只保留两个最大值
tensor([[[ 0,  0,  3,  4],
         [ 0,  0,  7,  8],
         [ 0,  0, 11, 12]],

        [[ 0,  0, 15, 16],
         [ 0,  0, 19, 20],
         [ 0,  0, 23, 24]]], dtype=torch.int32)

topk的输出有两个,其他地方可能会派上用场:

    a,b=x.topk(k=2,dim=2)
    print(a)
    print(b)
    
# 输出    
tensor([[[ 4,  3],
         [ 8,  7],
         [12, 11]],

        [[16, 15],
         [20, 19],
         [24, 23]]], dtype=torch.int32)
tensor([[[3, 2],
         [3, 2],
         [3, 2]],

        [[3, 2],
         [3, 2],
         [3, 2]]])

Process finished with exit code 0


PyTorchtopk函数是用于返回输入张量中指定维度上的前k个最大及其对应的索引。它的函数签名为torch.topk(input, k, dim=None, largest=True, sorted=True, out=None),返回一个元组,包含最大的k个组成的张量和它们在输入张量中的索引组成的长整型张量。其中,input是输入张量,k是要返回的最大的个数,dim是指定的维度,largest决定是否返回最大(默认为True),sorted决定是否返回排序的结果(默认为True),out是输出的张量。 例如,如果我们有一个输入张量input为[5, 9, 3, 2, 7],我们想要找出其中最大的3个及其索引,我们可以使用torch.topk(input, 3)。这将返回一个包含[9, 7, 5]的张量和一个包含[1, 4, 0]的长整型张量,分别表示最大的3个和它们在输入张量中的索引。 在具体的代码中,maxk = max(topk)用于获取topk列表中的最大,而output.topk(maxk, 1, True, True)则是对output进行topk操作,返回最大和对应的索引。这种用法可以帮助我们在代码中获取最大的k个及其索引。 总结来说,PyTorchtopk函数可以帮助我们在指定维度上找出输入张量中的最大及其对应的索引。这在许多机器学习和深度学习任务中非常有用。如果想要了解更多关于topk函数的用法,可以参考PyTorch官方中文文档或者一篇介绍topk函数用法的文章。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* *2* [pytorch 中的topk函数](https://blog.csdn.net/u012505617/article/details/103711019)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] - *3* [PyTorchtopk函数的用法详解](https://download.csdn.net/download/weixin_38628150/12856649)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"] [ .reference_list ]
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值