关于transformer中为什么要用mask

作者:王椗
链接:https://www.zhihu.com/question/369075515/answer/994819222
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
 

我自己的理解是这样的。

Transformer在训练的时候是并行执行的,所以在decoder的第一个sublayer里需要seq mask,其目的就是为了在预测未来数据时把这些未来的数据屏蔽掉,防止数据泄露。如果我们非要去串行执行training,seq mask其实就不需要了。比如说我们用transformer做NMT,训练数据里有一个sample是I love China -->我爱中国。利用串行的思维来想,在训练过程中,我们会

1. 把I love China输入到encoder里去,利用top encoder最终输出的tensor (size: 1X3X512,假设我们采用的embedding长度为512,而且batch size = 1)作为decoder里每一层用到的k和v;

2. 将<s>作为decoder的输入,将decoder最终的输出和‘我’做cross entropy计算error。

3. 将<s>,我作为decoder的输入,将decoder最终:输出的最后一个prob. vector和‘爱’做cross entropy计算error。

4. 将<s>,我,爱 作为decoder的输入,将decoder最终的输出的最后一个prob. vector和‘中’做cross entropy计算error。

5. 将<s>,我,爱,中 作为decoder的输入,将decoder最终的输出的最后一个prob. vector和‘国’做cross entropy计算error。

6. 将<s>,我,爱,中,国 作为decoder的输入,将decoder最终的输出的最后一个prob. vector和</s>做cross entropy计算error。

2-6里都可以不用seq mask。

而在transformer实际的training过程中,我们是并行地将2-6在一步中完成,即

7:将<s>,我,爱,中,国 作为decoder的输入,将decoder最终输出的5个prob. vector和我,爱,中,国,</s>分别做cross entropy计算error。

比如要想在7中计算第一个prob. vector的整个过程中,都不用到‘我’及其后面字的信息,就必需seq mask。对所有位置的输入,情况都是如此。

但是,仔细想想,7虽然包括了2-6,不过有一点区别。比如对3来说,我们是可以不用seq mask的,这时 <s>所对应的encoder output是会利用'我'里的信息的;而在并行时,seq mask是必需的,这时<s>所对应的encoder output是不会利用'我'里的信息的。

如此一来,我们可以看到,在transformer训练时,由于是并行计算,decoder的第i个输入只能用到i,i-1,..., 0这些位置上输入的信息;当训练完成后,在实际预测过程中,虽然理论上decoder的第i个输入可以用到所有位置上输入的信息,但是由于模型在训练过程中是按照前述方式训练的,所以继续使用seq mask会和训练方式匹配,得到更好的预测结果。

我感觉从理论上看,按照串行方式1-6来训练并且不用seq mask,我们可以把信息用得更足一些,似乎可能模型的效果会好一点,但是计算效率比transformer的并行训练差太多,最终综合来看应该还是并行的综合效果好。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值