Self-Attention自注意力机制
Attention的计算公式如下所示。
![](https://i-blog.csdnimg.cn/blog_migrate/e499f9c3154bb9ed35566f59d33a377c.png)
假设输入数据为x1和x2,首先通过Input Embedding层将其映射到更高的维度上,得到a1和a2。然后分别将a1和a2通过Wq、Wk和Wv三个参数矩阵生成各自对应的qkv,即q1、k1、v1和q2、k2、v2(Wq、Wk和Wv对于所有的的a都是相同的)。
![](https://i-blog.csdnimg.cn/blog_migrate/8b1310d5b0ea6ab0c4ae9ee026a04621.png)
其中qkv的计算公式如下所示,均为矩阵乘法。
![](https://i-blog.csdnimg.cn/blog_migrate/0b775037f44e89489522db9ee31cb948.png)
求解qkv的过程可以并行化计算(a直接进“行”拼接)
![](https://i-blog.csdnimg.cn/blog_migrate/86fca5141e56db1cd70197e7577e8d61.png)
然后进行match操作,即用a1的q1分别和其他ai的ki进行计算(包括a1的q1与自己的k1进行计算),然后再用a2的q2和其他ai的ki进行计算……直到将所有a的q都与其他ai的ki进行了计算。计算公式如下,其中d表示k的维度(本例中k的维度为2,因此d=2)
![](https://i-blog.csdnimg.cn/blog_migrate/e3afee8212335e6396d8401f4b8c01ae.png)
然后分别将每个a得到的所有α进行Soft-max处理,得到新的α,新的α表示的就是v的权重
![](https://i-blog.csdnimg.cn/blog_migrate/2669fe3c986b7b8a1bce0b93dd2a1758.png)
此过程也可以并行化进行(其中q直接进“行”拼接,k转置之后进行“列”拼接)
![](https://i-blog.csdnimg.cn/blog_migrate/8d19963d4c2f5c8dbf9e45321db7ef81.png)
之后让每个ai的所有新α与对应的v相乘,再将结果相加得到bi,即
![](https://i-blog.csdnimg.cn/blog_migrate/318d260481ecc1b27970f1de007735f0.png)
![](https://i-blog.csdnimg.cn/blog_migrate/34b8dcb9c949749cb2279d9f594ccc97.png)
此计算过程也可以并行化进行计算。
![](https://i-blog.csdnimg.cn/blog_migrate/759e26bee5198502d97216016c41ce47.png)
至此,Attention的计算公式完成。于是可以把Self-Attention抽象成一个模块,输入a得到b。
![](https://i-blog.csdnimg.cn/blog_migrate/e7a300d4d95c9f3b397bf4bb59dee51a.png)
Multi-Head Attention多头注意力机制
1个head的情况。
![](https://i-blog.csdnimg.cn/blog_migrate/8fbbd445abb5500744715b2acd651911.png)
对于多个head的情况,求每个输入的qkv的过程和self-Attention是一样的,不同之处是求完每个输入的qkv之后,要根据head的数量对所求得的qkv进行拆分。
![](https://i-blog.csdnimg.cn/blog_migrate/c8c69b36607e07da6f30acd726878eb3.png)
然后再将得到的所有qkv进行划分,分到不同的head下
![](https://i-blog.csdnimg.cn/blog_migrate/d39baea425df1bb275102ec1a379e0c6.png)
![](https://i-blog.csdnimg.cn/blog_migrate/6ef8a1816b8fa47b0ea8d00ea3c86f46.png)
然后对每一个head执行self-Attention中的操作
![](https://i-blog.csdnimg.cn/blog_migrate/99bfc117e02e984e77ce2bca65e16358.png)
然后再将每个head得到的结果进行拼接。
![](https://i-blog.csdnimg.cn/blog_migrate/55219b39fc8440a114395630628ec613.png)
之后通过一个参数矩阵Wo得到最终的b。
![](https://i-blog.csdnimg.cn/blog_migrate/16cfa0af9ecf2434d9d8f14b2ab95b04.png)
![](https://i-blog.csdnimg.cn/blog_migrate/cffd59bc2c2a795bc1a73761774e0634.png)
Position REncoding位置编码
![](https://i-blog.csdnimg.cn/blog_migrate/0aa0c3e0557669e34ef92755451a287d.png)