需要注意的是当掩码是1(true)的时候,是要掩住的内容,当掩码为0(false)的时候,是不需要掩住的内容
当然这是正常操作,还有不正常操作,比如当掩码为1的时候,但真正用的时候取个反,虽然原理还是一样,但是如果从掩码这个地方看的话就是正好反着了
我的主要写的内容取自这篇文章https://www.cnblogs.com/YuanShiRenY/p/torch_mask.html
但是感觉他有的地方讲反了,所以重新整理了一下
torch中的mask主要分为三种:masked_fill, masked_select, masked_scatter
masked_fill:
fill是填满,填补的意思,masked_fill就是对mask掉的内容填补成自己想要的数
输入:
import torch
x=torch.randint(200,[3,4])
print(x)
mask=torch.randint(2,[3,4])
print(mask)
torch.masked_fill(input=x,mask=mask,value=-1)
输出:
tensor([[ 99, 98, 91, 74],
[ 39, 104, 48, 149],
[196, 184, 76, 41]])
tensor([[0, 0, 0, 0],
[0, 0, 1, 1],
[0, 1, 0, 1]])
tensor([[ 99, 98, 91, 74],
[ 39, 104, -1, -1],
[196, -1, 76, -1]])
最后一行也可以写成
torch.masked_fill(x,mask,-1)
masked_select:
select是选择的意思,那么masked_select就是将没有mask掉的内容输出
输入:
import torch
x=torch.randint(200,[3,4])
print(x)
mask=torch.randint(2,[3,4]).bool()
print(mask)
torch.masked_select(x,mask)
输出:
tensor([[ 57, 146, 129, 138],
[ 20, 77, 54, 34],
[ 4, 48, 137, 110]])
tensor([[ True, False, True, False],
[False, False, True, False],
[False, True, False, False]])
tensor([ 57, 129, 54, 48])
masked_scatter:
对于mask掉的内容,根据source从头开始填充
输入:
import torch
x=torch.randint(200,[3,4])
print(x)
mask=torch.randint(2,[3,4]).bool()
print(mask)
# source = torch.zeros_like(x)
source = torch.randint(10,[3,4])
print(source)
torch.masked_scatter(input=x, mask=mask, source=source)
输出:
tensor([[116, 128, 181, 129],
[ 11, 16, 65, 167],
[ 58, 56, 38, 53]])
tensor([[False, True, True, True],
[False, False, True, False],
[False, False, False, False]])
tensor([[5, 6, 7, 5],
[6, 8, 9, 8],
[8, 1, 5, 9]])
tensor([[116, 5, 6, 7],
[ 11, 16, 5, 167],
[ 58, 56, 38, 53]])