转载请注明出处
一、为什么要使用attention机制?
使用attention机制的原因在于输入序列携带的信息并不是同等重要的,就像人观察一幅图画的时候注意力其实是集中在几个区域,而其余的地方则会忽略,attention就是模拟这个过程。
二、attention机制怎么实现?
图1 基于RNN结构的attention机制
以在RNN中加入attention机制为例,如上图所示。假设要实现对“机器学习”的翻译,attention机制的做法是给“机器学习”四个字分别进行加权后求和,以更多的保留较为关注的文字的信息。加权方式为使用match模块对输入的所有的字xn与隐藏状态h进行注意力分布运算(也有的文章将这部分叫“打分”),得到注意力值a01 a02 a03 a04 ,之后a01 a02 a03 a04 进入softmax得到注意力分布,然后注意力分布与对应输入的字xn相乘后求和,得到输出c0。如上图所示,“机器”二字的权值为0.5,“学习”二字的权值为0,则本次不关注“学习”二字。得到输出c0后,输入RNN翻译得出“machine”。
Match模块大致有四种打分函数:
s(xi,q)为打分函数。以图1为例,q即为图1中的h(在下文中以Q\K\V进行attention的介绍,此时Q不再是隐藏层状态h,而是x的线性变换),即隐藏状态;x为输入;W与U为参数矩阵,可通过训练得到;根号d为缩放值,据论文研究表明,除以根号d可以使梯度更稳定。
三、attention机制的分类
由以上attention的基本原理介绍,可以知道Attention的本质:寻址(addressing)或者加权求和。
图2 Attention机制结构
如图所示为attention机制的结构框图。图中的Q\K\V均由输入X乘以对应的矩阵W生成,若K=X=V,则为普通模式,若K!=X!=V,则为键值对模式。如下图所示:
图3 Attention机制分类
1)、普通模式
普通模式即为图1所示内容,不再赘述。
2)、键值对模式
键值对模式下注意力函数如下:
其中K\Q\V均为输入X与对应的矩阵相乘得到。
可以看到,随着输入的变化,键值对模型的注意力分布也在动态变化,这即是自注意力模型(self-attention model)。自注意力模型适合处理长度不固定的输入序列。自注意力模型通常使用缩放点积作为打分函数(缩放点积公式上文已给出,这里再描述一次。只是在自注意力模型中将上文公式中的X换成了K)。
要注意的一点是,无论是普通模式还是键值对模式,一个输入向量X生成的Q都需要与自己以及其他输入向量生成的K做运算。如下图所示,Q2与所有的K做运算生成X2的注意力分布,然后与所有输入生成的V加权求和得到X2的最终运算结果。
图4 注意力机制计算举例
三、多头自注意力(multi-head attention)与位置编码
1)、多头自注意力
多头自注意力的结构如图5所示。
图5 多头自注意力机制框图
其使用多个WK WQ WV生成多组K、Q、V,使用多组K、Q、V进行注意力分布运算,得到多个结果,将得到的多个结果拼接后再做矩阵运算得出最终的输出结果,运算过程如图6所示。
图6 多头自注意力运算过程
多头自注意力机制的其中一项作用是可以使网络注意到序列的不同部分,减少信息丢失。
2)、位置标记
可以注意到在上述计算过程中,输入的X的位置信息没有被纳入计算考虑,可以想到如果将K、V位置互换,输出结果也不会有什么不同,而神经网络中位置是非常重要的信息,它代表着输入序列的结构,所以必须将序列位置信息利用起来。可以给输入加上位置信息完成此项功能。如下图所示。
图7 输入序列加上位置信息
位置信息公式如下:
其中,pos代表位置,例如x4的pos就是4。dmodel代表输入向量的维度,假设输入x4是一列4行的向量,那么上述公式算出来的4个值就分别加到4个元素上,其中偶数位置的元素位置信息使用公式PE(pos,2i)得到,奇数位置元素的位置信息使用公式PE(pos,2i+1)得到。
参考资料:
nlp中的Attention注意力机制+Transformer详解 - 知乎 (zhihu.com)
细讲 | Attention Is All You Need (qq.com)