pytorch tensor 获取指定值的indices

import torch

x  = torch.randn(1, 3, 6, 6)
y = torch.zeros(x.shape).to(x.device)
y[x >= 0.5] = 1
z = (y == 1).nonzero(as_tuple = False) 
print(z)

# z is the indices you need.

要取出PyTorch张量中不为0的,并将它们合并到一个新的张量中,可以按照以下步骤进行: 1. 导入PyTorch库:首先需要导入PyTorch库,以便使用其中的函数和类。 2. 创建张量:可以使用torch.Tensor()函数创建一个张量,也可以根据实际需求选择其他合适的张量创建方式。 3. 获取不为0的元素:使用张量的非零索引函数(如nonzero())可以获取张量中所有不为0的元素的索引。例如,若张量名为"tensor",则可以通过tensor.nonzero()获取不为0的元素索引。 4. 提取不为0的:通过索引将不为0的从原始张量中提取出来。例如,可以使用tensor[indices]将不为0的提取出来,其中indices是通过nonzero()函数获取的不为0元素的索引。 5. 合并提取的:将提取的不为0的使用torch.cat()函数进行合并。可以使用torch.cat(tensor_list, dim)来将多个张量在指定维度上进行合并。其中,tensor_list是一个张量的列表,dim是要在哪个维度上进行合并。 具体代码如下所示: ```python import torch # 创建张量 tensor = torch.tensor([[1, 0, 3], [0, 5, 0], [7, 0, 9]]) # 获取不为0的元素索引 nonzero_indices = tensor.nonzero() # 提取不为0的 nonzero_values = tensor[nonzero_indices] # 合并提取的 merged_tensor = torch.cat(nonzero_values, dim=0) print(merged_tensor) ``` 这样就可以获取并合并原始张量中的所有不为0的到一个新的张量中。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值