今天我想给大家介绍这样一篇论文:Multi-Head Attention: Collaborate Instead of Concatenate。作者均来自
洛桑联邦理工学院_百度百科baike.baidu.com看过我文章的同学肯定知道,我一直在关注bert模型的性能优化相关研究,而这篇论文正好是与transformer的性能优化相关,并且我认为它的方法不需要做太多的适配就能应用在预训练模型上面,实用性较高,因此推荐给大家。
众所周知,经典的transformer架构中采用了multi-head attention机制来引导模型从不同角度学习不同的语义信息,从各种实验对比中也能发现多头机制确实能够提升模型在NLP任务上的精度。然而,随着目前大规模预训练模型的普及,多头注意力机制在带来精度提升的同时,也增加了计算的成本,带来了性能上的限制。
因此最近两年,有些研究人员尝试从不同的维度去探讨是否能从多头机制上去优化transformer的性能。有些工作重点关注了多头中每个头的注意力到底捕捉了哪些语义信息,头与头之间捕捉的信息是否有冗余,例如这篇论文:Analyzing multi-head self-attention: Specialized heads do the heavy lifting, the rest can be pruned,提出了一种量化注意力头重要程度的方法。还有一些工作更加激进,提出了多头注意力机制是否有必要的疑问,例如这篇论文:Are sixteen heads really better than one。它对transformer中的每个头都做了消融实验,探讨了每个头在不同下游NLP任务上的作用,最后提出了一种迭代式地剪枝注意力头的方法。
与上述工作不同,本篇论文并非直接对注意力头进行结构性剪枝,而是关注所有注意力头捕捉的通用信息,试图将这些信息提取出来作为sharing weights,每个头各自关注自己独有的工作,从而减少多头注意力计算时的成本。下面我就详细得为大家解读这篇论文的工作。
单个注意力头的减负
在那篇经典的Attention is all you need论文中,对于注意力分数的计算是这样的:
其中,
然而,在各种版本的transformer实现中,上述各种线性映射计算是附加bias的,即
在引入了bias后,我们重新对
备注一下:论文这里的公式貌似有点问题,最后一项应该是我推导出的项。
最后两项在做softmax的时候可以舍弃掉,为什么呢?其实很简单,我们得到的Attention分数是一个T*T的矩阵,而
另外,对于上述推导式的第一项,由于其计算了Query和key的相互关系,因此相当于捕捉了上下文的相关信息,而第二项只包含了key的content信息,相当于捕捉了原文内容上的信息。
多头注意力的整合
传统的transformer中,对于不同的注意力采取的整合方式是直接拼接,如下所示: