题主的问题其实是没有弄明白transformer中的Attention机制,以及attention机制的作用点而导致的。那接下来我们介绍一下transformer中的注意力机制,题主自然就明白了。
Multi-Headed Attention
transformer采用的是多头注意力机制。每个输入被分割成了多个头,允许网络注意每个输入的不同子部分。多头注意力网络结构详情如图图1 多头注意力网络结构
V、 K和Q代表“key”、“value”和“query”。这些是注意功能中使用的术语,但老实说,我不认为解释这些术语对于理解模型特别重要。这个网络结构在transformer中会使用三次,encoder,decoder,以及将encoder和decoder信息联合起来。下一节会详细介绍。
举个栗子,在Encoder中,V、K和G是相同的,都是输入向量的映射,尺寸为:batch sizesequence lengthembedding size。然后在多头注意力机制中,将输入向量分成N个头,这样它们就有了尺寸batch_sizeNsequence_length*(embedding size/N)。
对于图一中的scaled dot-product attention的具体计算,是针对每个头,都会进行如图二中的计算。即将Q,K相乘之后,经过变换,再与V相乘,最终得到每个头的attention之后的值。最后,再将每个头的值连接起来,得到最终的attention值。图2 注意力机制的计算
以上就是transformer运用的多头注意力机制的细节。