torch.where()函数详解

torch.where()函数详解

最近查了一些关于 torch.where() 函数的相关解释,有的高赞说torch.where()函数的作用是按照一定的规则合并两个tensor类型。其实合并这个词用的很不恰当,解释的很牵强。通过查找官方文档,整理一下对这个函数进行解释。

torch.where(condition, x, y) → Tensor
1. 函数解释

根据条件(condition),返回一个从 x 或 y 中选择的元素,即给定了条件限制,如果满足条件,则返回 x 的元素,否则返回 y 中的元素。

o u t i = { x i = i f   c o n d i t i o n i y i = o t h e r w i s e out_i=\left\{ \begin{aligned} x_i & = if \ condition_i \\ y_i & = otherwise \\ \end{aligned} \right. outi={xiyi=if conditioni=otherwise
注意:condition,x,y 必须是可广播的(broadcasting)。

2. 参数

condition (BoolTensor) - 当为 True 时,返回 x,否则返回 y

x (Tensor or Scalar) - 在条件为 True 的索引处返回的元素

y (Tensor or Scalar) - 在条件为 False 时在索引处返回的元素

3. 返回值

返回一个 tensor,tensor 的维度和广播后的 x 或 y 相同。

4. 举例说明

例子1:

import torch
x = torch.randn(3, 2)
y = torch.ones(3, 2)
a = torch.where(x > 0, x, y)
print(x)
print("shape of x = ",x.shape)
print(y)
print("shape of y = ",y.shape)
print(a)
print("shape of a = ",a.shape)

输出:

tensor([[ 0.3870, -0.1952],
        [-0.6414, -1.9798],
        [-0.0133,  0.9991]])
shape of x =  torch.Size([3, 2])
tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])
shape of y =  torch.Size([3, 2])
tensor([[0.3870, 1.0000],
        [1.0000, 1.0000],
        [1.0000, 0.9991]])
shape of a =  torch.Size([3, 2])

例子2:

import torch
x = torch.randn(2, 2, dtype=torch.double)
c = torch.where(x > 0, x, 0.)
print(x)
print("shape of x = ",x.shape)
print(c)
print("shape of c = ",c.shape)

输出:

tensor([[ 0.1419,  0.3472],
        [ 0.0400, -0.9591]], dtype=torch.float64)
shape of x =  torch.Size([2, 2])
tensor([[0.1419, 0.3472],
        [0.0400, 0.0000]], dtype=torch.float64)
shape of c =  torch.Size([2, 2])
torch.where() 函数还有如下形式:
torch.where(condition) → tuple of LongTensor

torch.where(condition)与torch.nonzero(condition, as_tuple=True)是一样的。稍后再更新torch.nonzero(condition, as_tuple=True)的相关说明。

官方文档链接:https://pytorch.org/docs/stable/generated/torch.where.html?highlight=where#torch.where

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值