import torch
# torch.where()函数的三种用法:
"""
1.torch.where(bool_tensor1 & bool_tensor2)接受两个元素为布尔值(True/False)组成的张量,返回两个张量组成的元组(turple),其中第一个张量中的元素是两布尔张量
bool_tensor1 和 bool_tensor2 均为True的元素的行索引,第二个张量中的元素是两布尔张量bool_tensor1 和 bool_tensor2均为True的元素的列索引。
[注意:两个布尔张量的shape必须满足广播规则,否则会引发shape在某个维度上不match的问题。]
--例1:
a1 = torch.tensor([[False, True, True],
[ True, False, False],
[False, False, True]])
b1 = torch.tensor([[True, True, True],
[True, True, True],
[True, True, True]])
c1 = torch.where(a1 & b1) # 注意两个布尔张量用&连接,而非,
print(c1) # (tensor([0, 0, 1, 2]), tensor([1, 2, 0, 2]))
b1是元素全为True的布尔张量,所以只需考虑a1中为True的元素的行列索引。a1中分别有4个为True的元素,分别是第一行(索引:0)的第二列(索引:1)第三列(索引:2),第二行(索引:1)
第一列(索引0)和第三行(索引:2)第三列(索引:2),写出两个布尔张量a1,a2都为True所有行列索引对(0, 1), (0, 2), (1, 0), (2, 2)。去除第0维度即行方向的索引值组成元组
c1的第一个张量即tensor([0, 0, 1, 2]),取出列方向的索引值组成元组c1的第一个张量即tensor([1, 2, 0, 2]))
--例2:(2.只接受一个布尔张量的特殊情形)
a2 = torch.tensor([True, False, True, False])
print(a2.shape) # torch.Size([4])
c2 = torch.where(a2)
print(c2) # (tensor([0, 2]),)
print(len(c2)) # 1
c3 = torch.where(b1)
print(b1.shape) # torch.Size([3, 3])
print(c3) # (tensor([0, 0, 0, 1, 1, 1, 2, 2, 2]), tensor([0, 1, 2, 0, 1, 2, 0, 1, 2]))
print(len(c3)) # 2
当torch.where(bool_tensor)函数只接受一个布尔张量,默认第二个张量为全True张量(torch.ones(bool_tensor, dtype=torch.bool)),所以只需考虑接受的这一个布尔张量
中为True元素的行列索引。
当torch.where(bool_tensor)函数接受的布尔张量均为一维张量,则最终返回的元组只有一个非空张量,对应该一维布尔张量为True元素的(列)索引。对a2,为True的元素(列)索引
为0, 2 因此c2 = torch.where(a2): (tensor([0, 2]),)
3.torch.where(condition, x, y)返回张量x, y中部分元素组成的张量,shape为张量x.shape和y.shape广播之后的shape,这里记为shape*
[注意:x.shape与y.shape可以不同,但必须满足广播规则。此外,标量和张量也可以进行广播。]
等价于如下伪代码(不严谨,考虑的是广播之后是二维张量情形):
for i in range(0, shape*[0]):
for j in range(0, shape*[1]):
if condition:
return x_ij
else:
return y_ij
--例3:
x = torch.tensor([[1, 3, 5]]) # (3, )
y = torch.tensor([[2], [4], [6]]) # (3, 1)
z = torch.where(x > 2, x, y) # broadcasting: (3, 3)
print(z)
#tensor([[2, 3, 5],
# [4, 3, 5],
# [6, 3, 5]])
print(z.shape) # torch.Size([3, 3])
--例4:
condition1 = torch.tensor([False, True, True, False])
x1 = torch.tensor([12, 22, 32, 42]) # (4, )
y1 = torch.tensor([1, 2, 3, 4]) # (4, )
z1 = torch.where(condition1, x1, y1)
print(z1) # tensor([ 1, 22, 32, 4])
print(z1.shape) # torch.Size([4])
"""
参考博客:【Torch API】pytorch中的torch.where()函数详解-CSDN博客