import torch
import numpy as np
import matplotlib.pyplot as plt
def subsequent_mask(size):
    """Build a mask that hides subsequent positions.

    Returns a boolean tensor of shape ``(1, size, size)``. Entry ``[0, i, j]``
    is True when ``j <= i`` (on or below the diagonal) and False when
    ``j > i``, so every position is masked off from the positions after it.
    """
    # 1 strictly above the diagonal (the "subsequent" positions), 0 elsewhere.
    future = torch.triu(torch.ones((1, size, size), dtype=torch.uint8), diagonal=1)
    # Flip to a bool mask that is True only where attention is permitted.
    return future == 0
# Demo: show the 5x5 mask; True marks positions that remain visible.
print(subsequent_mask(5))
# The mask is a bool tensor. Invert it (~ is logical NOT on bool tensors) so
# that 1 marks the masked-out subsequent positions, then convert the tensor to
# NumPy and cast to int for a compact 0/1 printout of the 10x10 case.
print((~subsequent_mask(10)).numpy().astype(int))
# Recorded output of the two prints above (whitespace may differ from a live run):
'''
tensor([[[ True, False, False, False, False],
[ True, True, False, False, False],
[ True, True, True, False, False],
[ True, True, True, True, False],
[ True, True, True, True, True]]])
[[[0 1 1 1 1 1 1 1 1 1]
[0 0 1 1 1 1 1 1 1 1]
[0 0 0 1 1 1 1 1 1 1]
[0 0 0 0 1 1 1 1 1 1]
[0 0 0 0 0 1 1 1 1 1]
[0 0 0 0 0 0 1 1 1 1]
[0 0 0 0 0 0 0 1 1 1]
[0 0 0 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 0 0 1]
[0 0 0 0 0 0 0 0 0 0]]]
'''