>>> help(torch.Tensor.masked_fill)
Help on method_descriptor:
masked_fill(...)
masked_fill(mask, value) -> Tensor
Out-of-place version of :meth:`torch.Tensor.masked_fill_`
>>> help(torch.Tensor.masked_fill_)
Help on method_descriptor:
masked_fill_(...)
masked_fill_(mask, value)
Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is
True. The shape of :attr:`mask` must be
:ref:`broadcastable <broadcasting-semantics>` with the shape of the underlying
tensor.
Args:
    mask (BoolTensor): the boolean mask
    value (float): the value to fill in with
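The docstrings above distinguish the out-of-place `masked_fill` from the in-place `masked_fill_`. Here is a minimal sketch with fixed, illustrative values (not from the session below) that makes the difference concrete. `masked_fill` returns a new tensor and leaves `t` untouched, while `masked_fill_` overwrites the tensor it is called on.

```python
import torch

t = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[True, False], [False, True]])

# Out-of-place: returns a new tensor; t itself is not modified
filled = t.masked_fill(mask, 0.0)
print(filled.tolist())  # [[0.0, 2.0], [3.0, 0.0]]
print(t.tolist())       # [[1.0, 2.0], [3.0, 4.0]]

# In-place (trailing underscore): modifies the tensor it is called on
t_inplace = t.clone()
t_inplace.masked_fill_(mask, 0.0)
print(t_inplace.tolist())  # [[0.0, 2.0], [3.0, 0.0]]
```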
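The masked_fill method takes two arguments, mask and value. mask is a PyTorch tensor (Tensor) whose elements are booleans, and value is the value to fill in. The rule: at every position where mask is True, the corresponding element of self is replaced with value. All other elements are kept.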
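This fill rule can be cross-checked against `torch.where`, which takes the fill value where the mask is True and the original element elsewhere. The sketch below uses hand-picked values instead of `torch.randn` so the result is reproducible. `torch.where` serves only as a cross-check here. It is not a claim about how `masked_fill` is implemented.

```python
import torch

# Fixed values instead of torch.randn, so the output is reproducible
t = torch.tensor([[-0.5, 0.25], [1.0, -2.0], [3.0, 4.0]])
m = torch.tensor([[0, 1], [1, 1], [1, 0]])
mask = m == 0  # bool tensor, True where m is 0

out = t.masked_fill(mask, -1e9)

# Same rule via torch.where: fill value where mask is True, original t elsewhere
ref = torch.where(mask, torch.full_like(t, -1e9), t)

print(torch.equal(out, ref))  # True
print(out.tolist())  # [[-1000000000.0, 0.25], [1.0, -2.0], [3.0, -1000000000.0]]
```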
>>> t = torch.randn(3,2)
>>> t
tensor([[-0.9180, -0.4654],
        [ 0.9866, -1.3063],
        [ 1.8359,  1.1607]])
>>> m = torch.randint(0,2,(3,2))
>>> m
tensor([[0, 1],
        [1, 1],
        [1, 0]])
>>> m == 0
tensor([[ True, False],
        [False, False],
        [False,  True]])
>>> t.masked_fill(m == 0, -1e9)
tensor([[-1.0000e+09, -4.6544e-01],
        [ 9.8660e-01, -1.3063e+00],
        [ 1.8359e+00, -1.0000e+09]])
Note: the mask m must either have the same size as t, or its shape must be broadcastable with t's shape (see broadcasting semantics), as shown below.
>>> m = torch.randint(0, 2, (3, 1))
>>> m
tensor([[0],
        [1],
        [1]])
# m (shape (3, 1)) and t (shape (3, 2)) are broadcastable
>>> t.masked_fill(m == 0, -1e9)
tensor([[-1.0000e+09, -1.0000e+09],
        [ 9.8660e-01, -1.3063e+00],
        [ 1.8359e+00,  1.1607e+00]])
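A (3, 1) mask behaves as if it had first been expanded to t's (3, 2) shape: each row's single boolean is repeated across the columns, which is why the whole first row above was filled. Here is a small sketch with fixed, illustrative values (not from the session above) that compares the broadcast call with an explicit `expand_as`:

```python
import torch

t = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # shape (3, 2)
m = torch.tensor([[0], [1], [1]])                        # shape (3, 1)

# Broadcast: the single boolean in each row of the mask applies to every column
out_broadcast = t.masked_fill(m == 0, -1e9)

# Same result with the mask explicitly expanded to t's shape first
out_expanded = t.masked_fill((m == 0).expand_as(t), -1e9)

print(torch.equal(out_broadcast, out_expanded))  # True
print(out_broadcast.tolist())  # [[-1000000000.0, -1000000000.0], [3.0, 4.0], [5.0, 6.0]]
```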
>>> m = torch.randint(0, 2, (3, 3))
>>> m
tensor([[0, 0, 0],
        [1, 1, 0],
        [1, 1, 1]])
# m (shape (3, 3)) and t (shape (3, 2)) are not broadcastable: dimension 1 is 3 vs 2
>>> t.masked_fill(m == 0, -1e9)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
For details on broadcasting, see the PyTorch documentation page "Broadcasting semantics".
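To check ahead of time whether a mask will broadcast to a tensor's shape, you can use `torch.broadcast_shapes`. The helper `mask_fits` below is hypothetical and not part of PyTorch. It returns True only when the broadcast result equals t's own shape, which matches the requirement stated in the `masked_fill_` docstring above. Treat it as a sketch, not a replacement for the library's own checks.

```python
import torch

def mask_fits(mask: torch.Tensor, t: torch.Tensor) -> bool:
    """Hypothetical helper: True if mask's shape broadcasts to exactly t's shape."""
    try:
        # torch.broadcast_shapes raises RuntimeError for incompatible shapes
        return torch.broadcast_shapes(mask.shape, t.shape) == t.shape
    except RuntimeError:
        return False

t = torch.zeros(3, 2)
for shape in [(3, 2), (3, 1), (2,), (3, 3)]:
    mask = torch.zeros(shape, dtype=torch.bool)
    print(shape, mask_fits(mask, t))
# (3, 2) True
# (3, 1) True
# (2,) True
# (3, 3) False
```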