torch.masked_select() usage
>>> import torch
>>> x=torch.randn(3,2)
>>> mask=x>0
>>> x
tensor([[-1.4701, 0.2248],
[-0.5485, 0.4736],
[ 0.5431, 0.0294]])
>>> mask
tensor([[False, True],
[False, True],
[ True, True]])
>>> torch.masked_select(x,mask)
tensor([0.2248, 0.4736, 0.5431, 0.0294])
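The session above selects the positive entries of `x`. Below is a minimal reproducible sketch of the same idea. It uses fixed values in place of `torch.randn`, and the names `selected` and `row_mask` are only illustrative. It shows three things:

- The result is a new 1-D tensor in row-major order.
- For a mask of the same shape, the result matches boolean indexing `x[mask]`.
- The mask only needs to be broadcastable to `x`.

```python
import torch

# Fixed values instead of torch.randn so the output is reproducible.
x = torch.tensor([[-1.0, 2.0],
                  [-3.0, 4.0],
                  [5.0, 6.0]])
mask = x > 0

# masked_select returns a new 1-D tensor of the selected elements in row-major order.
selected = torch.masked_select(x, mask)
print(selected.tolist())  # [2.0, 4.0, 5.0, 6.0]

# For a mask of the same shape, this matches boolean indexing.
assert torch.equal(selected, x[mask])

# The mask only has to be broadcastable to x: a (3, 1) mask selects whole rows.
row_mask = torch.tensor([[True], [False], [True]])
print(torch.masked_select(x, row_mask).tolist())  # [-1.0, 2.0, 5.0, 6.0]

# The result is a copy, so modifying it leaves x unchanged.
selected[0] = 100.0
print(x[0, 1].item())  # 2.0
```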
Tensor.masked_scatter() usage
>>> import torch
>>> x=torch.randn(3,4)
>>> mask=torch.randn(3,4)>0
>>> re=torch.randn(3,4)+100
>>> x
tensor([[ 1.1707, 1.0754, -0.7924, 0.0886],
[-1.6767, -1.6024, 0.7270, -1.8158],
[-0.1692, -0.3450, 0.8961, 1.0697]])
>>> mask
tensor([[ True, True, True, True],
[ True, False, False, True],
[ True, False, False, True]])
>>> re
tensor([[100.6069, 100.8317, 100.3769, 98.5353],
[ 99.1745, 99.4091, 100.2879, 100.2621],
[ 99.9853, 101.0823, 101.5993, 100.1091]])
>>> x.masked_scatter_(mask,re)
tensor([[100.6069, 100.8317, 100.3769, 98.5353],
[ 99.1745, -1.6024, 0.7270, 99.4091],
[100.2879, -0.3450, 0.8961, 100.2621]])
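re is written into x at the positions where mask is True. Note, however, that re is consumed sequentially in row-major order, not position by position: the k-th True position of mask receives the k-th element of re. That is why x[1][3] above becomes 99.4091 (re[1][1]) rather than re[1][3]. If re has fewer elements than mask has True entries, an error is raised. Any extra elements of re are simply left unused.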
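The ordering rule can be checked with a small emulation. This is a minimal sketch, not PyTorch's implementation. `masked_scatter_reference` is a hypothetical helper, and it assumes the mask is a bool tensor with exactly the same shape as `x` (the real method also accepts a broadcastable mask). The example uses zeros for `x` and 100, 101, ... for the source so the consumption order is easy to read. The mask has the same True/False pattern as the 3x4 example above.

```python
import torch

def masked_scatter_reference(x, mask, source):
    # Hypothetical helper for illustration, NOT a PyTorch API. It assumes
    # mask is a bool tensor with exactly the same shape as x (no broadcasting).
    num_true = int(mask.sum())
    if source.numel() < num_true:
        raise RuntimeError("source has fewer elements than True entries in mask")
    out = x.clone()
    # Boolean-index assignment visits the True positions in row-major order,
    # so the k-th True position receives the k-th element of the flattened
    # source; any remaining source elements are ignored.
    out[mask] = source.reshape(-1)[:num_true]
    return out


x = torch.zeros(3, 4)
# Same True/False pattern as the 3x4 mask in the example above.
mask = torch.tensor([[True, True,  True,  True],
                     [True, False, False, True],
                     [True, False, False, True]])
source = torch.arange(12, dtype=x.dtype).reshape(3, 4) + 100  # 100., 101., ..., 111.

result = masked_scatter_reference(x, mask, source)
print(result.tolist())
# [[100.0, 101.0, 102.0, 103.0], [104.0, 0.0, 0.0, 105.0], [106.0, 0.0, 0.0, 107.0]]

# The out-of-place Tensor.masked_scatter gives the same tensor.
assert torch.equal(x.masked_scatter(mask, source), result)

# Only 7 source values for 8 True positions: the emulation refuses.
try:
    masked_scatter_reference(x, mask, source.reshape(-1)[:7])
except RuntimeError as err:
    print("error:", err)
```

Source values 108 to 111 are never used, which matches the "extra elements are ignored" behaviour described above.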
>>> x=torch.randn([3,2])
>>> mask=x>0
>>> x
tensor([[ 0.4304, -0.4132],
[ 0.0061, 0.1836],
[-0.8996, -0.7297]])
>>> mask
tensor([[ True, False],
[ True, True],
[False, False]])
>>> re=torch.randn([2])
>>> x.masked_scatter_(mask,re)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of elements of source < number of ones in mask
>>> re = torch.randn([3])
>>> re
tensor([-0.5413, -0.1640, 0.1688])
>>> x.masked_scatter_(mask,re)
tensor([[-0.5413, -0.4132],
[-0.1640, 0.1688],
[-0.8996, -0.7297]])
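`torch.masked_select` reads elements in the same row-major order that `masked_scatter` writes them. Because of this, the two can be combined to transform only the masked entries. The sketch below is illustrative: the values are fixed, and the choice of multiplying by 10 is arbitrary. It uses the out-of-place `masked_scatter`, so `x` itself is not modified.

```python
import torch

x = torch.tensor([[0.5, -0.5],
                  [0.25, 2.0],
                  [-1.0, -2.0]])
mask = x > 0

# Pull the selected values out as a 1-D tensor (row-major order).
picked = torch.masked_select(x, mask)
print(picked.tolist())  # [0.5, 0.25, 2.0]

# Scattering them back unchanged restores x exactly, because the k-th True
# position receives the k-th picked value.
assert torch.equal(x.masked_scatter(mask, picked), x)

# Transform only the selected values, then write them back out-of-place.
y = x.masked_scatter(mask, picked * 10)
print(y.tolist())  # [[5.0, -0.5], [2.5, 20.0], [-1.0, -2.0]]
```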