def generate_sent_masks(batch_size, max_seq_length, source_lengths, *, dtype=None, device=None):
    """Generate sentence masks for encoder hidden states.

    Row ``i`` of the result has ones in its first ``source_lengths[i]``
    positions and zeros elsewhere. The ones mark the real (non-padding)
    tokens of sentence ``i``.

    Args:
        batch_size: Number of rows (sentences) in the mask.
        max_seq_length: Number of columns, i.e. the padded sequence length.
        source_lengths: Iterable of per-sentence lengths. If there are fewer
            entries than ``batch_size``, the remaining rows stay all-zero.
            Lengths larger than ``max_seq_length`` are clamped by slicing.
        dtype: Tensor dtype of the mask; defaults to ``torch.float``.
        device: Device on which to allocate the mask; defaults to torch's
            default device.

    Returns:
        Tensor of shape ``(batch_size, max_seq_length)`` holding the masks.

    Raises:
        ValueError: If any length in ``source_lengths`` is negative.
        IndexError: If ``source_lengths`` has more entries than ``batch_size``.
    """
    import torch  # local import: the visible file never imports torch

    if dtype is None:
        dtype = torch.float
    enc_masks = torch.zeros(batch_size, max_seq_length, dtype=dtype, device=device)
    for e_id, src_len in enumerate(source_lengths):
        # A negative slice end (e.g. ``:-1``) would silently mark all but the
        # last positions, producing a nonsensical mask.
        if src_len < 0:
            raise ValueError(f"source length must be non-negative, got {src_len} at index {e_id}")
        enc_masks[e_id, :src_len] = 1
    return enc_masks
句子填充到指定长度(mask矩阵生成)
于 2022-09-22 20:32:27 首次发布