import torch
import torch.nn.functional as F

# step 6: build the mask for decoder self-attention
# (batch_size and tgt_len are defined in the earlier steps of this script)
# A lower-triangular matrix is causal: each step sees only past and current
# positions when predicting the next one; each matrix is padded to the max
# target length so the batch can be stacked.
valid_decoder_tri_matrix = torch.cat(
    [torch.unsqueeze(
        F.pad(torch.tril(torch.ones((L, L))),
              (0, max(tgt_len) - L, 0, max(tgt_len) - L)), 0)
     for L in tgt_len])
# Invert: True marks positions that must NOT be attended to (future steps and padding).
invalid_decoder_tri_matrix = (1 - valid_decoder_tri_matrix).to(torch.bool)
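# Illustration (assuming, hypothetically, tgt_len = [2, 4]): the boolean mask
# for the first sequence, padded to length 4, would be
#   [[False,  True,  True,  True],
#    [False, False,  True,  True],
#    [ True,  True,  True,  True],
#    [ True,  True,  True,  True]]
# where True marks the future steps and padding to be masked out.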
# print(invalid_decoder_tri_matrix)
# Simulated attention scores for decoder self-attention; the shape must match the
# mask, i.e. (batch_size, max(tgt_len), max(tgt_len)), not max(src_len).
score = torch.randn(batch_size, max(tgt_len), max(tgt_len))
masked_score = score.masked_fill(invalid_decoder_tri_matrix, -1e9)  # masked positions -> ~0 after softmax
prob3 = F.softmax(masked_score, dim=-1)
print(tgt_len)
print(prob3)
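For a quick sanity check, here is a minimal self-contained version of the same step; tgt_len = [2, 4] and batch_size = 2 are hypothetical stand-ins for the values built in the earlier steps of the script.

import torch
import torch.nn.functional as F

batch_size = 2      # hypothetical; normally set in step 1
tgt_len = [2, 4]    # hypothetical target-sequence lengths

valid = torch.cat(
    [torch.unsqueeze(
        F.pad(torch.tril(torch.ones((L, L))),
              (0, max(tgt_len) - L, 0, max(tgt_len) - L)), 0)
     for L in tgt_len])
causal_mask = (1 - valid).to(torch.bool)

score = torch.randn(batch_size, max(tgt_len), max(tgt_len))
prob = F.softmax(score.masked_fill(causal_mask, -1e9), dim=-1)
print(prob.shape)  # torch.Size([2, 4, 4])

Note that rows corresponding entirely to padding are fully masked; with -1e9 the softmax over such a row degenerates to a uniform distribution, which is harmless as long as those positions are ignored downstream.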