pytorch masked_fill方法简单理解

本文详细介绍了PyTorch中的masked_fill方法,该方法用于根据布尔掩码mask将张量self中指定位置的元素替换为特定值value。文章通过实例展示了如何使用此方法,并解释了mask必须与self张量形状相匹配或可广播的要求。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

>>> help(torch.Tensor.masked_fill)
Help on method_descriptor:
masked_fill(...)
    masked_fill(mask, value) -> Tensor
    Out-of-place version of :meth:`torch.Tensor.masked_fill_`
>>> help(torch.Tensor.masked_fill_)
Help on method_descriptor:
masked_fill_(...)
    masked_fill_(mask, value)
    Fills elements of :attr:`self` tensor with :attr:`value` where :attr:`mask` is
    True. The shape of :attr:`mask` must be
    :ref:`broadcastable <broadcasting-semantics>` with the shape of the underlying
    tensor.

    Args:
        mask (BoolTensor): the boolean mask
        value (float): the value to fill in with

masked_fill方法有两个参数,maske和value,mask是一个pytorch张量(Tensor),元素是布尔值,value是要填充的值,填充规则是mask中取值为True位置对应于self的相应位置用value填充。

>>> t = torch.randn(3,2)
>>> t
tensor([[-0.9180, -0.4654],
        [ 0.9866, -1.3063],
        [ 1.8359,  1.1607]])
>>> m = torch.randint(0,2,(3,2))
>>> m
tensor([[0, 1],
        [1, 1],
        [1, 0]])
>>> m == 0
tensor([[ True, False],
        [False, False],
        [False,  True]])
>>> t.masked_fill(m == 0, -1e9)
tensor([[-1.0000e+09, -4.6544e-01],
        [ 9.8660e-01, -1.3063e+00],
        [ 1.8359e+00, -1.0000e+09]])

注意:参数m必须与t的size相同或者两者是可广播(broadcasting-semantics)的 如下

>>> m = torch.randint(0, 2, (3, 1))
>>> m
tensor([[0],
        [1],
        [1]])
# m和t是可广播的
>>> t.masked_fill(m == 0, -1e9)
tensor([[-1.0000e+09, -1.0000e+09],
        [ 9.8660e-01, -1.3063e+00],
        [ 1.8359e+00,  1.1607e+00]])

>>> m = torch.randint(0, 2, (3, 3))
>>> m
tensor([[0, 0, 0],
        [1, 1, 0],
        [1, 1, 1]])
# m和t是不可广播的
>>> t.masked_fill(m == 0, -1e9)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

关于广播(broadcasting-semantics),可参考pytorch 广播语义(Broadcasting semantics)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值