torch.where(),np.where()的两种用法,以及np.argwhere()寻找张量(tensor)和数组中为0的索引

1.torch.where()

torch.where()有两种用法,

1.当输入参数为三个时,即torch.where(condition, x, y),返回满足 x if condition else y的tensor,注意x,y必须为tensor

2.当输入参数为一个时,即torch.where(condition),返回满足condition的tensor索引的元组(tuple)

代码示例

torch.where(condition, x, y)

代码

import torch
import numpy as np

# 初始化两个tensor
x = torch.tensor([
    [1,2,3,0,6],
    [4,6,2,1,0],
    [4,3,0,1,1]
])
y = torch.tensor([
    [0,5,1,4,2],
    [5,7,1,2,9],
    [1,3,5,6,6]
])

# 寻找满足x中大于3的元素,否则得到y对应位置的元素
arr0 = torch.where(x>=3, x, y) #输入参数为3个

print(x, '\n', y)
print(arr0, '\n', type(arr0))

结果

>>> x
tensor([[1, 2, 3, 0, 6],
        [4, 6, 2, 1, 0],
        [4, 3, 0, 1, 1]])
>>> y
tensor([[0, 5, 1, 4, 2],
        [5, 7, 1, 2, 9],
        [1, 3, 5, 6, 6]])

>>> arr0
tensor([[0, 5, 3, 4, 6],
        [4, 6, 1, 2, 9],
        [4, 3, 5, 6, 6]])

>>> type(arr0)
<class 'torch.Tensor'>

arr0的类型为<class 'torch.Tensor'>

torch.where(condition)

以寻找tensor中为0的索引为例

代码

import torch
import numpy as np
x = torch.tensor([
    [1,2,3,0,6],
    [4,6,2,1,0],
    [4,3,0,1,1]
])
y = torch.tensor([
    [0,5,1,4,2],
    [5,7,1,2,9],
    [1,3,5,6,6]
])

# 返回x中0元素的索引
index0 = torch.where(x==0) # 输入参数为1个

print(index0,'\n', type(index0))

结果

>>> index0
(tensor([0, 1, 2]), tensor([3, 4, 2])) 

>>> type(index0)
<class 'tuple'>

其中[0, 1, 2]是0元素坐标的行索引,[3, 4, 2]是0元素坐标的列索引,注意,最终得到的是tuple类型的返回值,元组中包含了tensor

2.np.where()

np.where()用法与torch.where()用法类似,也包括两种用法,但是不同的是输入值类型和返回值的类型

代码示例

np.where(condition, x, y)和np.where(condition),输入x,y可以为非tensor

代码

import torch
import numpy as np
x = torch.tensor([
    [1,2,3,0,6],
    [4,6,2,1,0],
    [4,3,0,1,1]
])
y = torch.tensor([
    [0,5,1,4,2],
    [5,7,1,2,9],
    [1,3,5,6,6]
])

arr1 = np.where(x>=3, x, y) # 输入参数为3个

index0 = torch.where(x==0) # 输入参数为1个

print(arr1,'\n',type(arr1))
print(index1,'\n', type(index1))

结果

>>> arr1
[[0 5 3 4 6]
 [4 6 1 2 9]
 [4 3 5 6 6]]

>>> type(arr1)
<class 'numpy.ndarray'>

>>> index1
(array([0, 1, 2]), array([3, 4, 2])) 

>>> type(index1)
<class 'tuple'>

注意,np.where()和torch.where()的返回值类型不同

3.np.argwhere(condition)

寻找符合contion的元素索引

代码示例

代码

import torch
import numpy as np
x = torch.tensor([
    [1,2,3,0,6],
    [4,6,2,1,0],
    [4,3,0,1,1]
])
y = torch.tensor([
    [0,5,1,4,2],
    [5,7,1,2,9],
    [1,3,5,6,6]
])


index2 = np.argwhere(x==0) # 寻找元素为0的索引

print(index2,'\n', type(index2))

结果

>>> index2
tensor([[0, 1, 2],
        [3, 4, 2]]) 

>>> type(index2)
<class 'torch.Tensor'>

注意返回值的类型

评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值