概述
transformer中的mask操作,可以分成encoder端和decoder端
- Enocoder中的mask
首先确定mask的形状是 batch_size*seq_length
对于小于最大长度的句子进行补0操作,对于大于最大长度的句子进行截断操作。
虽然对少于最大长度的句子进行了补零操作但是这些0仍然会参与注意力分数的计算。
这里需要将mask中的0变成负无穷,1变成0,与计算的注意力矩阵相加,原来有单词的注意力不变,没有单词的位置变换成负无穷,之后在进行softmax运算 - decoder端的mask
decoder的mask形状是一个下三角矩阵,解码器在翻译单词时只能看到前面已经翻译的单词,不能看到后面的答案,所以使用一个下三角矩阵进行遮盖,每一次解码只给解码器看前面的单词。