不得了的Transformer模型(二)
下面我们讲解transformer网络中到底是怎么计算的,里面的维度是咋变化的。
一、整体流程来一遍呀
大体流程咱们先走一遍,咱们以encoder输入为例。首先输入的是一个长度为m的序列X = (x1, x2, x3, ……,xm)。在输入到encoder之前,我们需要对输入X进行一顿操作,即添加标志符号在序列前端添加<GO>标志符,在序列末端添加<EOS>标志符号。不同的代码书写符号可能并不相同,但是他们的意义是完全相同的。经过预处理我们得到了输入序列X,也就是下图中的inputs,接下来我们一步一步进行解析(我们先不考虑batch_size)。
图1
第一步,输入序列X(seq_len为m)经过Input Embedding,进行一个查表操作后,X变成了(seq_len, embedding_len)向量矩阵。
第二步,X(seq_len, embedding_len)经过Positional Encoding,即位置编码,使得X对应加上位置信息向量矩阵,由于两者维度相同,所以得到X维度不变,依然是(seq_len, embedding_len)。
第三步,由图1左我们可以看出来,在X向上输入的过程中,其中有一条路直接通到了Add&Norm,这是为了后面的残差连接操作。另一条路分为三叉,分别乘以WQ,WK,WV,变成了我们所熟悉的Q,K,V。这里我们需要知道WQ和WK的维度要相同,WV的维度要和embedding的维度保持一致。经过公式一的运算之后得到的结果与X进行残差连接并且进行层归一化处理,也就是Add&Norm操作。经过6层操作,最终的输入到decoder的向量矩阵维度为(seq_len, embedding),然后带入到decoder中进行运算,运算过程大致相同。
公式一
二、Multi-Head Attention
我们在这里详细的讲解一下多头注意力是如何拆开和合并的。
首先我们的输入Q(batch_size, seq_len, dk), K(batch_size, seq_len, dk), V(batch_size, seq_len, dv)。由于模型是8头注意力,所以我们对Q,K,V进行处理得到Q(batch_size, seq_len, 8, dk//8), K(batch_size, seq_len,8, dk//8), V(batch_size, seq_len, 8,dv//8)。然后咱们根据公式一进行展示。
首先是Q:我们这里使用tf.matmul操作,即维度操作为(batch_size, seq_len, 8, dk//8)* (batch_size, seq_len, dk//8, 8),这样dk//8抵消,得到的结果为(batch_size, seq_len, 8, 8)
然后对(batch_size, seq_len, 8, 8)乘以一个因子操作并进行softmax,此过程并不改变维度。
最终点乘V,同样使用了tf.matmul操作。即(batch_size, seq_len, 8, 8) * (batch_size, seq_len, 8, dv//8),这样8进行抵消,得到的结果为(batch_size, seq_len, 8, dv//8)。
操作结束后我们进行了Concat操作,这样(batch_size, seq_len, 8, dv//8)最终变成了(batch_size, seq_len, dv)。这里的dv就是embedding_size。