torch.where()理解

1. torch.where(condition, x, y) → Tensor

返回一个tensor,其中的元素根据condition从x或y中选取
 out  i = { x i  if   condition  i y i  otherwise  \text { out }_{i}= \begin{cases}\mathrm{x}_{i} & \text { if } \text { condition }_{i} \\ \mathrm{y}_{i} & \text { otherwise }\end{cases}  out i={xiyi if  condition i otherwise 

case 1:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)

2. torch.where(condition) → tuple of LongTensor

这种情况下torch.where完全等同于torch.nonzero(condition, as_tuple=True),即找出condition不等于0的元素的索引,当condition是二维tensor时,返回一个n*2的tensor,condition是三维时,返回n*3的tensor.

case2:

torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True)
(tensor([0, 1, 2, 4]),)
>>> a = torch.randn(4,3)
>>> a
tensor([[-0.1272,  1.0898,  0.1808],
        [-0.6453, -1.5093,  2.1270],
        [ 0.3852,  1.0197, -0.2998],
        [-0.2550,  1.2353, -0.1208]])
>>> b = (a>0.1)
>>> b
tensor([[False,  True,  True],
        [False, False,  True],
        [ True,  True, False],
        [False,  True, False]])
>>> b.nonzero()
tensor([[0, 1],
        [0, 2],
        [1, 2],
        [2, 0],
        [2, 1],
        [3, 1]])
>>> a = torch.randn(2,3,4)
>>> a
tensor([[[ 0.0642, -2.1055, -1.0024,  0.6682],
         [-0.1130, -0.5724,  0.4137, -0.4402],
         [ 1.3719,  0.7411,  1.7062, -0.9893]],
 
        [[-0.1432,  0.4898, -0.4913,  1.1397],
         [-0.0883,  1.1581, -0.4784,  0.0195],
         [ 1.3767,  2.2682, -0.6881,  1.9919]]])
 
>>> b = (a>0.1)
>>> b
tensor([[[False, False, False,  True],
         [False, False,  True, False],
         [ True,  True,  True, False]],
 
        [[False,  True, False,  True],
         [False,  True, False, False],
         [ True,  True, False,  True]]])
>>> b.nonzero()
tensor([[0, 0, 3],
        [0, 1, 2],
        [0, 2, 0],
        [0, 2, 1],
        [0, 2, 2],
        [1, 0, 1],
        [1, 0, 3],
        [1, 1, 1],
        [1, 2, 0],
        [1, 2, 1],
        [1, 2, 3]])
  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值