pytorch掩码(masked)
pytorch使用tensor.masked_fill
将张量中的一些值掩盖掉。
在Transformer中与BERT中都有用到。
方法原型:tensor.masked_fill(mask, value)
- 将mask中为1的部分使用value替代(value通常是一个极大或极小值),0的部分保持原值。
- mask必须是一个
ByteTensor
类型的张量(由01组成) - value是替代值,一般为:
1e9/1e-9
示例:
import torch
x = torch.arange(24).reshape(3, 2, 4)
print(x.shape)
masked = torch.ByteTensor([[[1, 0, 1, 0], [0, 1, 1, 1]],
[[1, 0, 1, 0], [0, 1, 1, 1]],
[[1, 0, 1, 0], [0, 1, 1, 1]]])
print(masked.shape)
x = x.masked_fill(masked, 1e9)
print(x)
print("======================================================")
x = torch.arange(24).reshape(3, 2, 4)
print(x.shape)
masked = torch.ByteTensor([[[1], [0]],
[[1], [0]],
[[1], [0]]])
print(masked.shape)
x = x.masked_fill(masked, 1e9)
print(x)
result:
torch.Size([3, 2, 4])
torch.Size([3, 2, 4])
tensor([[[1000000000, 1, 1000000000, 3],
[ 4, 1000000000, 1000000000, 1000000000]],
[[1000000000, 9, 1000000000, 11],
[ 12, 1000000000, 1000000000, 1000000000]],
[[1000000000, 17, 1000000000, 19],
[ 20, 1000000000, 1000000000, 1000000000]]])
======================================================
torch.Size([3, 2, 4])
torch.Size([3, 2, 1])
tensor([[[1000000000, 1000000000, 1000000000, 1000000000],
[ 4, 5, 6, 7]],
[[1000000000, 1000000000, 1000000000, 1000000000],
[ 12, 13, 14, 15]],
[[1000000000, 1000000000, 1000000000, 1000000000],
[ 20, 21, 22, 23]]])
第一部分是按照原始形状进行掩码,对应位置进行替换。第二部分是进行广播后再进行掩码,所以倒数第二个维度中的所有元素都进行了mask。