torch.where()函数详解
最近查了一些关于 torch.where() 函数的相关解释,有的高赞说torch.where()函数的作用是按照一定的规则合并两个tensor类型。其实合并这个词用的很不恰当,解释的很牵强。通过查找官方文档,整理一下对这个函数进行解释。
torch.where(condition, x, y) → Tensor
1. 函数解释
根据条件(condition),返回一个从 x 或 y 中选择的元素,即给定了条件限制,如果满足条件,则返回 x 的元素,否则返回 y 中的元素。
o
u
t
i
=
{
x
i
=
i
f
c
o
n
d
i
t
i
o
n
i
y
i
=
o
t
h
e
r
w
i
s
e
out_i=\left\{ \begin{aligned} x_i & = if \ condition_i \\ y_i & = otherwise \\ \end{aligned} \right.
outi={xiyi=if conditioni=otherwise
注意:condition,x,y 必须是可广播的(broadcasting)。
2. 参数
condition (BoolTensor) - 当为 True 时,返回 x,否则返回 y
x (Tensor or Scalar) - 在条件为 True 的索引处返回的元素
y (Tensor or Scalar) - 在条件为 False 时在索引处返回的元素
3. 返回值
返回一个 tensor,tensor 的维度和广播后的 x 或 y 相同。
4. 举例说明
例子1:
import torch
x = torch.randn(3, 2)
y = torch.ones(3, 2)
a = torch.where(x > 0, x, y)
print(x)
print("shape of x = ",x.shape)
print(y)
print("shape of y = ",y.shape)
print(a)
print("shape of a = ",a.shape)
输出:
tensor([[ 0.3870, -0.1952],
[-0.6414, -1.9798],
[-0.0133, 0.9991]])
shape of x = torch.Size([3, 2])
tensor([[1., 1.],
[1., 1.],
[1., 1.]])
shape of y = torch.Size([3, 2])
tensor([[0.3870, 1.0000],
[1.0000, 1.0000],
[1.0000, 0.9991]])
shape of a = torch.Size([3, 2])
例子2:
import torch
x = torch.randn(2, 2, dtype=torch.double)
c = torch.where(x > 0, x, 0.)
print(x)
print("shape of x = ",x.shape)
print(c)
print("shape of c = ",c.shape)
输出:
tensor([[ 0.1419, 0.3472],
[ 0.0400, -0.9591]], dtype=torch.float64)
shape of x = torch.Size([2, 2])
tensor([[0.1419, 0.3472],
[0.0400, 0.0000]], dtype=torch.float64)
shape of c = torch.Size([2, 2])
torch.where() 函数还有如下形式:
torch.where(condition) → tuple of LongTensor
torch.where(condition)与torch.nonzero(condition, as_tuple=True)是一样的。稍后再更新torch.nonzero(condition, as_tuple=True)的相关说明。
官方文档链接:https://pytorch.org/docs/stable/generated/torch.where.html?highlight=where#torch.where