# Reference: adapted from a blog post (see link below).
def sequence_mask(lengths, max_len=None):
    """Build a boolean mask marking valid positions in padded sequences.

    Args:
        lengths: integer tensor of arbitrary shape; each entry is a
            sequence length.
        max_len: width of the mask (number of positions). Defaults to
            ``int(lengths.max())`` when not given.

    Returns:
        Bool tensor of shape ``lengths.shape + (max_len,)`` where
        ``mask[..., j]`` is True iff ``j < lengths[...]``.
    """
    lengths_shape = lengths.shape  # torch.Size is a tuple, so we can extend it below
    lengths = lengths.reshape(-1)  # flatten so the comparison broadcasts uniformly
    batch_size = lengths.numel()
    # Use an explicit None check: `max_len or ...` would wrongly discard an
    # explicit max_len=0 because 0 is falsy.
    if max_len is None:
        max_len = int(lengths.max())
    lengths_shape += (max_len,)  # final output shape: original shape + mask width
    # Compare every position index [0, max_len) against every length:
    # positions strictly less than the length are valid (True).
    return (torch.arange(0, max_len, device=lengths.device)
            .type_as(lengths)
            .unsqueeze(0).expand(batch_size, max_len)
            .lt(lengths.unsqueeze(1))).reshape(lengths_shape)