torch.masked_select()和Tensor.masked_scatter()的用法

本文详细介绍了如何在PyTorch中使用masked_select函数选择满足条件的张量元素,并展示了masked_scatter函数如何根据mask更新张量值,特别强调了mask长度与源数据长度匹配的重要性。
摘要由CSDN通过智能技术生成

 torch.masked_select()用法

>>> import torch
>>> x=torch.randn(3,2)
>>> mask=x>0
>>> x
tensor([[-1.4701,  0.2248],
        [-0.5485,  0.4736],
        [ 0.5431,  0.0294]])
>>> mask
tensor([[False,  True],
        [False,  True],
        [ True,  True]])

>>> torch.masked_select(x,mask)
tensor([0.2248, 0.4736, 0.5431, 0.0294])

Tensor.masked_scatter()用法

>>> import torch
>>> x=torch.randn(3,4)
>>> mask=torch.randn(3,4)>0
>>> re=torch.randn(3,4)+100

>>> x
tensor([[ 1.1707,  1.0754, -0.7924,  0.0886],
        [-1.6767, -1.6024,  0.7270, -1.8158],
        [-0.1692, -0.3450,  0.8961,  1.0697]])
>>> mask
tensor([[ True,  True,  True,  True],
        [ True, False, False,  True],
        [ True, False, False,  True]])
>>> re
tensor([[100.6069, 100.8317, 100.3769,  98.5353],
        [ 99.1745,  99.4091, 100.2879, 100.2621],
        [ 99.9853, 101.0823, 101.5993, 100.1091]])

>>> x.masked_scatter_(mask,re)
tensor([[100.6069, 100.8317, 100.3769,  98.5353],
        [ 99.1745,  -1.6024,   0.7270,  99.4091],
        [100.2879,  -0.3450,   0.8961, 100.2621]])

re根据mask的True位置来赋值x,但要注意,re是按顺序来赋值给x的,re数量小于mask的true数量就会报错。

>>> x=torch.randn([3,2])
>>> mask=x>0
>>> x
tensor([[ 0.4304, -0.4132],
        [ 0.0061,  0.1836],
        [-0.8996, -0.7297]])
>>> mask
tensor([[ True, False],
        [ True,  True],
        [False, False]])

>>> re=torch.randn([2])
>>> x.masked_scatter_(mask,re)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of elements of source < number of ones in mask

>>> re =torch.randn([3])
>>> re
tensor([-0.5413, -0.1640,  0.1688])
>>> x.masked_scatter_(mask,re)
tensor([[-0.5413, -0.4132],
        [-0.1640,  0.1688],
        [-0.8996, -0.7297]])

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值