pytorch每日一学49(torch.where())根据t指定条件更改指定tensor中的数值

第49个方法

torch.where(condition, x, y) → Tensor

此方法是将x中的元素和条件相比,如果符合条件就还等于原来的元素,如果不满足条件的话,那就去y中对应的值,公式为
在这里插入图片描述
举个例子就清楚了,官方例子如下:

>>> x = torch.randn(3, 2)
>>> y = torch.ones(3, 2)
>>> x
tensor([[-0.4620,  0.3139],
        [ 0.3898, -0.7197],
        [ 0.0478, -0.1657]])
>>> torch.where(x > 0, x, y)
tensor([[ 1.0000,  0.3139],
        [ 0.3898,  1.0000],
        [ 0.0478,  1.0000]])
>>> x = torch.randn(2, 2, dtype=torch.double)
>>> x
tensor([[ 1.0779,  0.0383],
        [-0.8785, -1.1089]], dtype=torch.float64)
>>> torch.where(x > 0, x, 0.)
tensor([[1.0779, 0.0383],
        [0.0000, 0.0000]], dtype=torch.float64)

其中第一个例子,取x大于零,那么x中小于0的值就会变成y中与之相对应位置的值,由于y里面都是1,所以x中小于0的位置都变成了1。

第二个例子条件仍然为x>0,但是这时的y是0,那么x中小于0的值就都变成了0。

其实是有对应位置关系的,x中元素不满足条件时,取得是y中对应的元素来填充里面的值。如下所示:
在这里插入图片描述
a中第一个小于零的值为-1,被b中对应的值3直接填补,而-2被7填补,所以填补是有对应关系的,所以x和y必须是可广播的,否则就会报错。

标量和张量也可以组合,就如同官方的第二个示例,其中x就是张量,而y就是标量。目前有效地标量和张量的组合为:

  • float型和double型
  • integral型和long型
  • complex型和complex 128型

而torch.where(condition)与torch.nonzero(condition, as tuple=True)相等,就是返回了其中满足条件的元素的位置,不过是以元组的形式返回的,第一个元组代表了第一个维度的位置,第二个代表第二维,一次类推,如下所示:

在这里插入图片描述
其中取了a中大于0的元素,返回了它们的位置,取第一个里面的为第一维,第二个里面的为第二维,即位置为(0, 0), (0, 3)…,详情请看torch.nonzero()

  • 11
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值