第35个方法
torch.masked_select(input, mask, *, out=None) → Tensor
此方法中mask是一个bool矩阵,在input中取出mask中True对应的值。
首先介绍参数:
input(tensor)
:需要进行处理的tensor。mask(BoolTensor)
:包含了二进制掩码,要进行索引的tensor。out
:输出的结果tensor(结果为一维tensor)。
使用方法如下:
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297, 0.3477],
[-1.2035, 1.2252, 0.5002, 0.6248],
[ 0.1307, -2.0608, 0.1244, 2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],
[False, True, True, True],
[False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252, 0.5002, 0.6248, 2.0139])
使用方法其实挺明显的,就是把input与mask相对应起来,取出mask中True所对应位置的数据,组成一维的tensor。
注意:mask和input的形状可以不相同,但是它们必须是可以广播的。并且返回tensor和原tensor使用不同的内存,相互独立。