题主的问题其实是没有弄明白transformer中的Attention机制,以及attention机制的作用点而导致的。那接下来我们介绍一下transformer中的注意力机制,题主自然就明白了。
Multi-Headed Attention
transformer采用的是多头注意力机制。每个输入被分割成了多个头,允许网络注意每个输入的不同子部分。多头注意力网络结构详情如图图1 多头注意力网络结构
V、 K和Q代表“key”、“value”和“query”。这些是注意功能中使用的术语,但老实说,我不认为解释这些术语对于理解模型特别重要。这个网络结构在transformer中会使用三次,encoder,decoder,以及将encoder和decoder信息联合起来。下一节会详细介绍。
举个栗子,在Encoder中,V、K和G是相同的,都是输入向量的映射,尺寸为:batch sizesequence lengthembedding size。然后在多头注意力机制中,将输入向量分成N个头,这样它们就有了尺寸batch_sizeNsequence_length*(embedding size/N)。
对于图一中的scaled dot-product attention的具体计算,是针对每个头,都会进行如图二中的计算。即将Q,K相乘之后,经过变换,再与V相乘,最终得到每个头的attention之后的值。最后,再将每个头的值连接起来,得到最终的attention值。图2 注意力机制的计算
以上就是transformer运用的多头注意力机制的细节。
multi-head attention的具体作用点图3
如上图所示,是transformer的整体结构。可以清晰地看见,多头注意力机制的三个作用点。 其中,在Encoder中多头注意力的Q、K、V输入都是一样的,就是整个encoder的输入;在decoder中,输入也是一样的,即为上一个step的输出的映射值;但是在最后运用decoder去查询encoder信息时,发生了不一样的故事,此时的K和V都为encoder的输出,但是Q为decoder的输出。
其实从上面这个介绍就可以很清晰的看出来,在每一个step,decoder会去查询encoder中的信息,从而得到一个output,即为每一个step输出的概率分布;然后再将这个概率分布作为下一个step的decoder的输入,重复以上操作,得到新的概率分布,直到序列结束。
这个过程就是题主想说的RNN的模拟过程吧~