torch.where()函数的三种用法

import torch
# torch.where()函数的三种用法:

"""
1.torch.where(bool_tensor1 & bool_tensor2)接受两个元素为布尔值(True/False)组成的张量,返回两个张量组成的元组(turple),其中第一个张量中的元素是两布尔张量
bool_tensor1 和 bool_tensor2 均为True的元素的行索引,第二个张量中的元素是两布尔张量bool_tensor1 和 bool_tensor2均为True的元素的列索引。
[注意:两个布尔张量的shape必须满足广播规则,否则会引发shape在某个维度上不match的问题。]
--例1:
a1 = torch.tensor([[False,  True,  True],
                   [ True, False, False],
                   [False, False,  True]])

b1 = torch.tensor([[True, True, True],
                   [True, True, True],
                   [True, True, True]])
c1 = torch.where(a1 & b1)  # 注意两个布尔张量用&连接,而非, 
print(c1)  # (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))

b1是元素全为True的布尔张量,所以只需考虑a1中为True的元素的行列索引。a1中分别有4个为True的元素,分别是第一行(索引:0)的第二列(索引:1)第三列(索引:2),第二行(索引:1)
第一列(索引0)和第三行(索引:2)第三列(索引:2),写出两个布尔张量a1,a2都为True所有行列索引对(0, 1), (0, 2), (1, 0), (2, 2)。去除第0维度即行方向的索引值组成元组
c1的第一个张量即tensor([0, 0, 1, 2]),取出列方向的索引值组成元组c1的第一个张量即tensor([1, 2, 0, 2]))


--例2:(2.只接受一个布尔张量的特殊情形)
a2 = torch.tensor([True, False, True, False])
print(a2.shape)  # torch.Size([4])
c2 = torch.where(a2)
print(c2)  # (tensor([0, 2]),)
print(len(c2))  # 1

c3 = torch.where(b1)
print(b1.shape)  # torch.Size([3, 3])
print(c3)  # (tensor([0, 0, 0, 1, 1, 1, 2, 2, 2]), tensor([0, 1, 2, 0, 1, 2, 0, 1, 2]))
print(len(c3))  # 2

当torch.where(bool_tensor)函数只接受一个布尔张量,默认第二个张量为全True张量(torch.ones(bool_tensor, dtype=torch.bool)),所以只需考虑接受的这一个布尔张量
中为True元素的行列索引。
当torch.where(bool_tensor)函数接受的布尔张量均为一维张量,则最终返回的元组只有一个非空张量,对应该一维布尔张量为True元素的(列)索引。对a2,为True的元素(列)索引
为0, 2 因此c2 = torch.where(a2): (tensor([0, 2]),)

3.torch.where(condition, x, y)返回张量x, y中部分元素组成的张量,shape为张量x.shape和y.shape广播之后的shape,这里记为shape*
[注意:x.shape与y.shape可以不同,但必须满足广播规则。此外,标量和张量也可以进行广播。]

等价于如下伪代码(不严谨,考虑的是广播之后是二维张量情形):
for i in range(0, shape*[0]): 
  for j in range(0, shape*[1]):
    if condition:
      return x_ij
    else:
      return y_ij

--例3:
x = torch.tensor([[1, 3, 5]])  # (3, )
y = torch.tensor([[2], [4], [6]])  # (3, 1)
z = torch.where(x > 2, x, y)  # broadcasting: (3, 3)

print(z)
#tensor([[2, 3, 5],
#        [4, 3, 5],
#        [6, 3, 5]])
print(z.shape)  # torch.Size([3, 3])

--例4:
condition1 = torch.tensor([False, True, True, False])

x1 = torch.tensor([12, 22, 32, 42]) # (4, )
y1 = torch.tensor([1, 2, 3, 4])  # (4, )

z1 = torch.where(condition1, x1, y1)
print(z1)  # tensor([ 1, 22, 32,  4])
print(z1.shape)  # torch.Size([4])
"""
参考博客:【Torch API】pytorch中的torch.where()函数详解-CSDN博客

  • 18
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值