pytorch文档
https://pytorch.org/docs/stable/tensors.html?highlight=masked_fill#torch.Tensor.masked_fill_
in-place
,即原地操作符,直接在原来的内存上改变变量的值,可以称为原地操作符(以masked_fill_()和masked_fill()为例说明)
注:当前浅粉色显示设置方法为,在句子的两端添加" `` ",即键盘左上角波浪线下的符号
函数说明-masked_fill_(mask, value)
- Fills elements of self tensor with value
where mask is True
. The shape of mask must be broadcastable with the shape of the underlying tensor.(当mask值为True时候,用value进行填充,tensor和mask的形状需要保持一致)
举例说明
# 例1
import torch
data = torch.tensor([[1,0,0],[1,1,0],[1,1,1]])
mask = torch.tensor([[False,True,True],[False,False,True],[False,False,False]])
new_data = data.masked_fill(mask,-100)
print(data)
# print(new_data)
'''
data--没改变原来的值
tensor([[1, 0, 0],
[1, 1, 0],
[1, 1, 1]])
'''
new_data = data.masked_fill_(mask,-100)
print(data)
# print(new_data)
'''
data--改变了原来的值
tensor([[ 1, -100, -100],
[ 1, 1, -100],
[ 1, 1, 1]])
'''
# 例二
# mask[:,None,None],会使得mask升维度,[3,3]-->[3,1,1,3]
import torch
data = torch.tensor([[1,0,0],[1,1,0],[1,1,1]])
mask = torch.tensor([[1,0,0],[1,1,0],[1,1,1]])
# mask=mask[:,None,None].eq(0)
# print(mask.size())
new_data = data.masked_fill_(mask=mask.eq(0),value=-100)
print(data)
# print(new_data)