作者:王椗
链接:https://www.zhihu.com/question/369075515/answer/994819222
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
我自己的理解是这样的。
Transformer在训练的时候是并行执行的,所以在decoder的第一个sublayer里需要seq mask,其目的就是为了在预测未来数据时把这些未来的数据屏蔽掉,防止数据泄露。如果我们非要去串行执行training,seq mask其实就不需要了。比如说我们用transformer做NMT,训练数据里有一个sample是I love China -->我爱中国。利用串行的思维来想,在训练过程中,我们会
1. 把I love China输入到encoder里去,利用top encoder最终输出的tensor (size: 1X3X512,假设我们采用的embedding长度为512,而且batch size = 1)作为decoder里每一层用到的k和v;
2. 将<s>作为decoder的输入,将decoder最终的输出和‘我’做cross entropy计算error。
3. 将<s>,我作为decoder的输入,将decoder最终:输出的最后一个prob. vector和‘爱’做cross entropy计算error。
4. 将<s>,我,爱 作为decoder的输入,将decoder最终的输出的最后一个prob. vector和‘中’做cross entropy计算error。
5. 将<s>,我,爱,中 作为decoder的输入,将decoder最终的输出的最后一个prob. vector和‘国’做cross entropy计算error。
6. 将<s>,我,爱,中,国 作为decoder的输入,将decoder最终的输出的最后一个prob. vector和</s>做cross entropy计算error。
2-6里都可以不用seq mask。
而在transformer实际的training过程中,我们是并行地将2-6在一步中完成,即
7:将<s>,我,爱,中,国 作为decoder的输入,将decoder最终输出的5个prob. vector和我,爱,中,国,</s>分别做cross entropy计算error。
比如要想在7中计算第一个prob. vector的整个过程中,都不用到‘我’及其后面字的信息,就必需seq mask。对所有位置的输入,情况都是如此。
但是,仔细想想,7虽然包括了2-6,不过有一点区别。比如对3来说,我们是可以不用seq mask的,这时 <s>所对应的encoder output是会利用'我'里的信息的;而在并行时,seq mask是必需的,这时<s>所对应的encoder output是不会利用'我'里的信息的。
如此一来,我们可以看到,在transformer训练时,由于是并行计算,decoder的第i个输入只能用到i,i-1,..., 0这些位置上输入的信息;当训练完成后,在实际预测过程中,虽然理论上decoder的第i个输入可以用到所有位置上输入的信息,但是由于模型在训练过程中是按照前述方式训练的,所以继续使用seq mask会和训练方式匹配,得到更好的预测结果。
我感觉从理论上看,按照串行方式1-6来训练并且不用seq mask,我们可以把信息用得更足一些,似乎可能模型的效果会好一点,但是计算效率比transformer的并行训练差太多,最终综合来看应该还是并行的综合效果好。