def test_masked_fill():
    """Demonstrate ``Tensor.masked_fill`` with an all-zero mask.

    Builds a random 5x5 tensor and a 5x5 mask of zeros, then replaces every
    position where ``mask == 0`` with ``-1e9``. Because the mask is all
    zeros, every element of the result becomes ``-1e9``. The input, the mask
    and the masked result are printed; nothing is returned.
    """
    # Plain tensors are used directly: the ``Variable`` wrapper is deprecated
    # and simply hands back a Tensor. The local is not called ``input`` so the
    # builtin of the same name is not shadowed.
    scores = torch.randn(5, 5)
    print(f"input:{scores}")
    # Sample output (values are random and differ on every run):
    # input:tensor([[-0.2357,  0.8143,  0.4173, -0.6915,  0.2506],
    #         [-0.5638,  0.9063, -0.6678,  0.4305,  0.8954],
    #         [ 3.2776, -0.4731,  0.2869, -0.1473,  0.6105],
    #         [ 0.0249,  0.7301, -0.3991,  1.1885,  0.5604],
    #         [-0.5454,  0.5988, -0.1047, -0.2669,  0.6869]])

    mask = torch.zeros(5, 5)
    print(f"mask:{mask}")
    # Sample output:
    # mask:tensor([[0., 0., 0., 0., 0.],
    #         [0., 0., 0., 0., 0.],
    #         [0., 0., 0., 0., 0.],
    #         [0., 0., 0., 0., 0.],
    #         [0., 0., 0., 0., 0.]])

    # Wherever mask == 0, overwrite the value with a very large negative
    # number (-1e9 here).
    masked = scores.masked_fill(mask == 0, -1e9)
    print(f"masked:{masked}")
    # Sample output (every row is identical):
    # masked:tensor([[-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
    #         ...
    #         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]])

    # The mask is all zeros, so every position must have been replaced.
    assert bool((masked == -1e9).all())
03-02
80
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
07-23
758
![](https://csdnimg.cn/release/blogv2/dist/pc/img/readCountWhite.png)
09-14