Transfomer矩阵维度分析及MultiHead详解


解读Transformer就离不开下面这张图:

不同于之前的基于rnn的seq2seq模型,Transfomer完全摒弃了循环神经网络的结构:

  1. encoder层: {多头自注意力 + 前馈网络} × n \times n ×n
  2. decoder层: {Masked 多头自注意力 + encoder-decoder多头自注意力 + 前馈网络} × n \times n ×n

下面我们介绍Transformer模型中的矩阵维度变化情况:

矩阵维度分析

对于一个batch的数据,encoder端的输入大小为:(batch_size, sr_len);decoder端的输入大小为:(batch_size, tar_len)。不妨假设 encoder layer 及 decoder layer 都只有一层,下面是训练阶段的矩阵维度变化:

训练阶段

训练阶段 Encoder

input_sizeLayeroutput_sizeLayer_parameter_sizeNote
batch_size × \times × sr_lenInput Embeddingbatch_size × \times × sr_len × \times × embed_sizesr_vocab_size × \times × embed_sizeEmbedding层的参数即可设为可学习的,也可设为固定参数
batch_size × \times × sr_len × \times × embed_sizePostion Embeddingbatch_size × \times × sr_len × \times × embed_size1 × \times × sr_len × \times × embed_size固定参数
batch_size × \times × sr_len × \times × embed_sizeMultiHead Attentionbatch_size × \times × sr_len × \times × hidden_size{embed_size × \times × hidden_size} × \times × 3 + {hidden_size × \times × hidden_size}可学习参数
batch_size × \times × sr_len × \times × hidden_sizeAddNorm1batch_size × \times × sr_len × \times × hidden_sizeNone
batch_size × \times × sr_len × \times × hidden_sizeFeed Forwardbatch_size × \times × sr_len × \times × hidden_size{hidden_size × \times × filter_size} + {filter_size × \times × hidden_size}可学习参数
batch_size × \times × sr_len × \times × hidden_sizeAddNorm2batch_size × \times × sr_len × \times × hidden_sizeNone

训练阶段 Decoder

input_sizeLayeroutput_sizeLayer_parameter_sizeNote
batch_size × \times × tar_lenOutput Embeddingbatch_size × \times × tar_len × \times × embed_sizetar_vocab_size × \times × embed_sizeEmbedding层的参数即可设为可学习的,也可设为固定参数
batch_size × \times × tar_len × \times × embed_sizePostion Embeddingbatch_size × \times × tar_len × \times × embed_size1 × \times × tar_len × \times × embed_size固定参数
batch_size × \times × tar_len × \times × embed_sizeMasked MultiHead Attentionbatch_size × \times × tar_len × \times × hidden_size{embed_size × \times × hidden_size} × \times × 3 + {hidden_size × \times × hidden_size}可学习参数
batch_size × \times × tar_len × \times × hidden_sizeAddNorm1batch_size × \times × tar_len × \times × hidden_sizeNone
batch_size × \times × tar_len × \times × hidden_sizeEncoder-Decoder MultiHead Attentionbatch_size × \times × tar_len × \times × hidden_size{hidden_size × \times × hidden_size} × \times × 4可学习参数
batch_size × \times × tar_len × \times × hidden_sizeAddNorm2batch_size × \times × tar_len × \times × hidden_sizeNone
batch_size × \times × tar_len × \times × hidden_sizeFeed Forwardbatch_size × \times × tar_len × \times × hidden_size{hidden_size × \times × filter_size} + {filter_size × \times × hidden_size}可学习参数
batch_size × \times × tar_len × \times × hidden_sizeAddNorm3batch_size × \times × tar_len × \times × hidden_sizeNone

注意到,为了保持encoder及decoder的层可以堆叠,需要保证每个层的输入和输出的维度一致,因此,需要保证 embed_size = hidden_size


预测阶段

预测阶段的 encoder 与训练阶段是相同的,只是 batch_size = 1;而 decoder 部分由于每个 step 只能看到当前位置之前的信息,因此每次输入的 tar_len 也等于 1。

预测阶段 Decoder

input_sizeLayeroutput_size
1 × \times × 1Output Embedding1 × \times × 1 × \times × embed_size
1 × \times × 1 × \times × embed_sizePostion Embedding1 × \times × 1 × \times × embed_size
1 $\times$1 × \times × embed_sizeMasked MultiHead Attention1 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeAddNorm11 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeEncoder-Decoder MultiHead Attention1 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeAddNorm21 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeFeed Forward1 × \times × 1 × \times × hidden_size
1 × \times × 1 × \times × hidden_sizeAddNorm31 × \times × 1 × \times × hidden_size

Multihead Attention解析

训练阶段

Encoder Multihead Attention

在这里插入图片描述

  1. Input: Encoder Multihead Attention 输入的 query, key, value 是相同的,都是经过了word embedding和pos embedding之后的 source sentence,其维度为 batch_size × sr_len × hidden_size \text{batch\_size} \times \text{sr\_len} \times \text{hidden\_size} batch_size×sr_len×hidden_size 。由于有 num_heads 个头需要并行计算,首先 query, key, value 分别经过一个线性变换,再将数据 split 给 num_heads 个头分别做注意力查询,即:
    q u e r y : batch_size × sr_len_q × hidden_size ⟹ 线性变换 batch_size × sr_len_q × hidden_size ⟹ reshape batch_size × num_heads × sr_len_q × hidden_size num_heads k e y : batch_size × sr_len_k × hidden_size ⟹ 线性变换 batch_size × sr_len_k × hidden_size ⟹ reshape batch_size × num_heads × sr_len_k × hidden_size num_heads v a l u e : batch_size × sr_len_v × hidden_size ⟹ 线性变换 batch_size × sr_len_v × hidden_size ⟹ reshape batch_size × num_heads × sr_len_v × hidden_size num_heads \begin{aligned} \boldsymbol {query}: \text{batch\_size} \times \text{sr\_len\_q} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_q} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}: \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}: \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \end{aligned} query:batch_size×sr_len_q×hidden_size线性变换batch_size×sr_len_q×hidden_sizereshapebatch_size×num_heads×sr_len_q×num_headshidden_sizekey:batch_size×sr_len_k×hidden_size线性变换batch_size×sr_len_k×hidden_sizereshapebatch_size×num_heads×sr_len_k×num_headshidden_sizevalue:batch_size×sr_len_v×hidden_size线性变换batch_size×sr_len_v×hidden_sizereshapebatch_size×num_heads×sr_len_v×num_headshidden_size

由于query, key, value 是相同的,因此有 sr_len_q = sr_len_k = sr_len_v

  1. DotProductAttention: num_heads 个头的计算是并行的,即:
    q u e r y : batch_size × num_heads × sr_len_q × hidden_size num_heads k e y : batch_size × num_heads × sr_len_k × hidden_size num_heads v a l u e : batch_size × num_heads × sr_len_v × hidden_size num_heads ⇓ q u e r y ∗ k e y T = batch_size × num_heads × sr_len_q × sr_len_k ⇓ 消 除 k e y 中 padding 的 影 响 , 对 其 做 mask masked_softmax ( q u e r y ∗ k e y T ) = batch_size × num_heads × sr_len_q × sr_len_k ⇓ masked_softmax ( q u e r y ∗ k e y T ) ∗ v a l u e = batch_size × num_heads × sr_len_q × hidden_size num_heads \begin{aligned} \boldsymbol {query}: \text{batch\_size} \times \text{num\_heads} &\times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}: \text{batch\_size} \times \text{num\_heads} &\times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}: \text{batch\_size} \times \text{num\_heads} &\times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \Downarrow\\ \boldsymbol {query} * \boldsymbol {key}^T = \text{batch\_size} \times \text{num\_heads}& \times \text{sr\_len\_q} \times \text{sr\_len\_k}\\ \stackrel{消除 \boldsymbol {key} 中 \text{padding} 的影响,对其做 \text{mask}}{\Downarrow}\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) = \text{batch\_size} \times \text{num\_heads}& \times \text{sr\_len\_q} \times \text{sr\_len\_k}\\ \Downarrow\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) * \boldsymbol {value} = \text{batch\_size} \times \text{num\_heads}& \times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} } \end{aligned} query:batch_size×num_headskey:batch_size×num_headsvalue:batch_size×num_headsquerykeyT=batch_size×num_headskeypaddingmaskmasked_softmax(querykeyT)=batch_size×num_headsmasked_softmax(querykeyT)value=batch_size×num_heads×sr_len_q×num_headshidden_size×sr_len_k×num_headshidden_size×sr_len_v×num_headshidden_size×sr_len_q×sr_len_k×sr_len_q×sr_len_k×sr_len_q×num_headshidden_size

Encoder Multihead Attention 中在计算 softmax 之前对 key 进行了 mask,目的是消除 padding 的影响。事实上 padding 不仅对 key 有影响,对 query 也有影响,但在实际代码中 mask 仅针对 key,而没有针对 query。其实最原始代码是既有 key mask,也有query mask的,但后来作者将 query mask 删去了,因为在最后计算 loss 的时候对 padding 位置的 loss 进行mask,也可达到相同的效果。

假设 batch_size = num_heads = 1,sr_len_q = sr_len_k = 6,source sentence 的最后两个位置是padding,那么Encoder Multihead Attention 中的 mask 为:
( 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 ) \begin{pmatrix} 1 & 1 & 1 & 1 & 0 & 0\\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \end{pmatrix} 111111111111111111111111000000000000
即只对 key 的 padding 位置进行了 mask

  1. Output: 需要将上面输出的 num_heads 个头的结果堆叠之后,再做一个线性变换:
    batch_size × num_heads × sr_len_q × hidden_size num_heads ⇓ reshape batch_size × sr_len_q × hidden_size ⇓ 线性变换 batch_size × sr_len_q × hidden_size \begin{aligned} \text{batch\_size} \times \text{num\_heads}& \times \text{sr\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \stackrel{\text{reshape}}{\Downarrow}\\ \text{batch\_size} \times \text{sr\_len\_q} &\times \text{hidden\_size}\\ \stackrel{\text{线性变换}}{\Downarrow}\\ \text{batch\_size} \times \text{sr\_len\_q} &\times \text{hidden\_size} \end{aligned} batch_size×num_headsreshapebatch_size×sr_len_q线性变换batch_size×sr_len_q×sr_len_q×num_headshidden_size×hidden_size×hidden_size

Masked Multihead Attention

与 Encoder Multihead Attention 类似,Masked Multihead Attention 输入的 query, key, value 也是相同的,都是经过了word embedding和pos embedding之后的 target sentence。包括后面的计算流程也基本一致。

主要的区别在于:由于在 inference 时,每个 step 位置只能看到它之前的 steps 的信息,而看不到它之后的 steps的信息。因此 Masked Multihead Attention 中的 mask 除了要消除 key 信息里 padding 的影响,还需要消除当前 step 后面的所有 step 的信息:

假设 batch_size = num_heads = 1,tar_len_q = tar_len_k = 5,target sentence 的最后两个位置是 padding,那么Masked Multihead Attention 中的 mask 为:
( 1 0 0 0 0 1 1 0 0 0 1 1 1 0 0 1 1 1 0 0 1 1 1 0 0 ) \begin{pmatrix} 1 & 0 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \end{pmatrix} 1111101111001110000000000
注意到上述 mask 并不是一个单纯的下三角矩阵,因为最后两个位置都是padding,因此无论如何都要被 mask 掉


Encoder-Decoder Multihead Attention

  1. Input: Encoder-Decoder Multihead Attention 输入的 query 来自于 target sentence,其维度为 batch_size × tar_len × hidden_size \text{batch\_size} \times \text{tar\_len} \times \text{hidden\_size} batch_size×tar_len×hidden_size ;而 key 和 value 则来自于 encoder layer 的输出,其维度为 batch_size × sr_len × hidden_size \text{batch\_size} \times \text{sr\_len} \times \text{hidden\_size} batch_size×sr_len×hidden_size 。同样是先做线性变换,再 split 成 num_heads 个头:
    q u e r y : batch_size × tar_len_q × hidden_size ⟹ 线性变换 batch_size × tar_len_q × hidden_size ⟹ reshape batch_size × num_heads × tar_len_q × hidden_size num_heads k e y : batch_size × sr_len_k × hidden_size ⟹ 线性变换 batch_size × sr_len_k × hidden_size ⟹ reshape batch_size × num_heads × sr_len_k × hidden_size num_heads v a l u e : batch_size × sr_len_v × hidden_size ⟹ 线性变换 batch_size × sr_len_v × hidden_size ⟹ reshape batch_size × num_heads × sr_len_v × hidden_size num_heads \begin{aligned} \boldsymbol {query}: \text{batch\_size} \times \text{tar\_len\_q} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{tar\_len\_q} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{tar\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}: \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}: \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{batch\_size} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{batch\_size} \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \end{aligned} query:batch_size×tar_len_q×hidden_size线性变换batch_size×tar_len_q×hidden_sizereshapebatch_size×num_heads×tar_len_q×num_headshidden_sizekey:batch_size×sr_len_k×hidden_size线性变换batch_size×sr_len_k×hidden_sizereshapebatch_size×num_heads×sr_len_k×num_headshidden_sizevalue:batch_size×sr_len_v×hidden_size线性变换batch_size×sr_len_v×hidden_sizereshapebatch_size×num_heads×sr_len_v×num_headshidden_size

这里 sr_len_q ≠ \neq = sr_len_k = sr_len_v

  1. DotProductAttention: num_heads 个头的计算依然可以并行:
    q u e r y ∗ k e y T = batch_size × num_heads × tar_len_q × sr_len_k ⇓ 消 除 k e y 中 padding 的 影 响 , 对 其 做 mask masked_softmax ( q u e r y ∗ k e y T ) = batch_size × num_heads × tar_len_q × sr_len_k ⇓ masked_softmax ( q u e r y ∗ k e y T ) ∗ v a l u e = batch_size × num_heads × tar_len_q × hidden_size num_heads \begin{aligned} \boldsymbol {query} * \boldsymbol {key}^T = \text{batch\_size} \times \text{num\_heads}& \times \text{tar\_len\_q} \times \text{sr\_len\_k}\\ \stackrel{消除 \boldsymbol {key} 中 \text{padding} 的影响,对其做 \text{mask}}{\Downarrow}\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) = \text{batch\_size} \times \text{num\_heads}& \times \text{tar\_len\_q} \times \text{sr\_len\_k}\\ \Downarrow\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) * \boldsymbol {value} = \text{batch\_size} \times \text{num\_heads}& \times \text{tar\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} } \end{aligned} querykeyT=batch_size×num_headskeypaddingmaskmasked_softmax(querykeyT)=batch_size×num_headsmasked_softmax(querykeyT)value=batch_size×num_heads×tar_len_q×sr_len_k×tar_len_q×sr_len_k×tar_len_q×num_headshidden_size

假设 batch_size = num_heads = 1,这里sr_len_q可以不等于sr_len_k,不妨假设 sr_len_q = 5,sr_len_k = 6因为 mask 只针对key,因此这里只需要关注 source sentence 中的padding, 假设 source sentence 的最后两个位置是padding,那么Masked Multihead Attention 中的 mask 为:
( 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 ) \begin{pmatrix} 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 & 0 \end{pmatrix} 111111111111111111110000000000

  1. Output: 需要将上面输出的 num_heads 个头的结果堆叠之后,再做一个线性变换:
    batch_size × num_heads × tar_len_q × hidden_size num_heads ⇓ reshape batch_size × tar_len_q × hidden_size ⇓ 线性变换 batch_size × tar_len_q × hidden_size \begin{aligned} \text{batch\_size} \times \text{num\_heads}& \times \text{tar\_len\_q} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \stackrel{\text{reshape}}{\Downarrow}\\ \text{batch\_size} \times \text{tar\_len\_q} &\times \text{hidden\_size}\\ \stackrel{\text{线性变换}}{\Downarrow}\\ \text{batch\_size} \times \text{tar\_len\_q} &\times \text{hidden\_size} \end{aligned} batch_size×num_headsreshapebatch_size×tar_len_q线性变换batch_size×tar_len_q×tar_len_q×num_headshidden_size×hidden_size×hidden_size

预测阶段

Encoder Multihead Attention

与训练阶段的 Encoder Multihead Attention 完全相同

Masked Multihead Attention

虽然在训练阶段,Masked Multihead Attention 会将当前 step 之后的 steps 信息都 mask 掉,但是由于训练时整个 target sentence 都是已知的,因此还是可以做并行运算的。

但是在预测阶段,初始的 query, key, value 都只是一个 “<bos>” 起始符号,之后每预测出一个 token,这个 token 直接作为下一个 step 输入的 query,而将这个 token 拼在现有的 key 和 value 之后,就是下一个 step 输入的 key 和 value。也就是说,预测阶段每个 step 输入的 query 是上一 step 输出的token,而 key, value 是之前所有 steps 输出的token

至于 mask 的部分,由于输入中不再含有未来 steps 的信息,因此不再需要用 mask 来消除这部分信息。而对于 key mask,由于 Masked Multihead Attention 的 key 是 target sentence,而在预测完成前 target sentence 的长度是未知的,因此针对 key 的 mask 也是不需要的也就是说,Masked Multihead Attention 是不需要 mask 的

下面是预测阶段 Masked Multihead Attention 的流程:

  1. Input: key, value 是到当前 step 为止的所有 steps 的信息,大小为 1 × \times × cur_tar_len × \times × hidden_size;而 query 是上一 step 的输出 token,大小为 1 × \times × 1 × \times × hidden_size:
    q u e r y : 1 × 1 × hidden_size ⟹ 线性变换 1 × 1 × hidden_size ⟹ reshape 1 × num_heads × 1 × hidden_size num_heads k e y : 1 × cur_tar_len_k × hidden_size ⟹ 线性变换 1 × cur_tar_len_k × hidden_size ⟹ reshape 1 × num_heads × cur_tar_len_k × hidden_size num_heads v a l u e : 1 × cur_tar_len_v × hidden_size ⟹ 线性变换 1 × cur_tar_len_v × hidden_size ⟹ reshape 1 × num_heads × cur_tar_len_v × hidden_size num_heads \begin{aligned} \boldsymbol {query}&: \text{1} \times \text{1} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{1} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}&: \text{1} \times \text{cur\_tar\_len\_k} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{cur\_tar\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{cur\_tar\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}&: \text{1} \times \text{cur\_tar\_len\_v} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{cur\_tar\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{cur\_tar\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \end{aligned} querykeyvalue:1×1×hidden_size线性变换1×1×hidden_sizereshape1×num_heads×1×num_headshidden_size:1×cur_tar_len_k×hidden_size线性变换1×cur_tar_len_k×hidden_sizereshape1×num_heads×cur_tar_len_k×num_headshidden_size:1×cur_tar_len_v×hidden_size线性变换1×cur_tar_len_v×hidden_sizereshape1×num_heads×cur_tar_len_v×num_headshidden_size

  2. DotProductAttention: num_heads 个头的计算依然可以并行:
    q u e r y ∗ k e y T = 1 × num_heads × 1 × cur_tar_len_k ⇓ softmax ( q u e r y ∗ k e y T ) = 1 × num_heads × 1 × cur_tar_len_k ⇓ softmax ( q u e r y ∗ k e y T ) ∗ v a l u e = 1 × num_heads × 1 × hidden_size num_heads \begin{aligned} \boldsymbol {query} * \boldsymbol {key}^T = \text{1} \times \text{num\_heads}& \times \text{1} \times \text{cur\_tar\_len\_k}\\ \Downarrow\\ \text{softmax}(\boldsymbol {query} * \boldsymbol {key}^T) = \text{1} \times \text{num\_heads}& \times \text{1} \times \text{cur\_tar\_len\_k}\\ \Downarrow\\ \text{softmax}(\boldsymbol {query} * \boldsymbol {key}^T) * \boldsymbol {value} = \text{1} \times \text{num\_heads}& \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} } \end{aligned} querykeyT=1×num_headssoftmax(querykeyT)=1×num_headssoftmax(querykeyT)value=1×num_heads×1×cur_tar_len_k×1×cur_tar_len_k×1×num_headshidden_size

  3. Output: 需要将上面输出的 num_heads 个头的结果堆叠之后,再做一个线性变换:
    1 × num_heads × 1 × hidden_size num_heads ⇓ reshape 1 × 1 × hidden_size ⇓ 线性变换 1 × 1 × hidden_size \begin{aligned} \text{1} \times \text{num\_heads}& \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ &\stackrel{\text{reshape}}{\Downarrow}\\ \text{1} \times \text{1} &\times \text{hidden\_size}\\ &\stackrel{\text{线性变换}}{\Downarrow}\\ \text{1} \times \text{1} &\times \text{hidden\_size} \end{aligned} 1×num_heads1×11×1×1×num_headshidden_sizereshape×hidden_size线性变换×hidden_size


Encoder-Decoder Multihead Attention

预测阶段 Encoder-Decoder Multihead Attention 输入的 query 是上一层 Masked Multihead Attention 的输出,大小为 1 × \times × 1 × \times × hidden_size。而输入的 key 和 value 则是 encoder layer 的输出,大小为:1 × \times × sr_len × \times × hidden_size。具体流程为:

  1. Input:
    q u e r y : 1 × 1 × hidden_size ⟹ 线性变换 1 × 1 × hidden_size ⟹ reshape 1 × num_heads × 1 × hidden_size num_heads k e y : 1 × sr_len_k × hidden_size ⟹ 线性变换 1 × sr_len_k × hidden_size ⟹ reshape 1 × num_heads × sr_len_k × hidden_size num_heads v a l u e : 1 × sr_len_v × hidden_size ⟹ 线性变换 1 × sr_len_v × hidden_size ⟹ reshape 1 × num_heads × sr_len_v × hidden_size num_heads \begin{aligned} \boldsymbol {query}&: \text{1} \times \text{1} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{1} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {key}&: \text{1} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{sr\_len\_k} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{sr\_len\_k} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \boldsymbol {value}&: \text{1} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{线性变换}}{\Longrightarrow } \text{1} \times \text{sr\_len\_v} \times \text{hidden\_size} \stackrel{\text{reshape}}{\Longrightarrow } \text{1} \times \text{num\_heads} \times \text{sr\_len\_v} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ \end{aligned} querykeyvalue:1×1×hidden_size线性变换1×1×hidden_sizereshape1×num_heads×1×num_headshidden_size:1×sr_len_k×hidden_size线性变换1×sr_len_k×hidden_sizereshape1×num_heads×sr_len_k×num_headshidden_size:1×sr_len_v×hidden_size线性变换1×sr_len_v×hidden_sizereshape1×num_heads×sr_len_v×num_headshidden_size

  2. DotProductAttention: num_heads 个头的计算并行:
    q u e r y ∗ k e y T = 1 × num_heads × 1 × sr_len_k ⇓ 消 除 k e y 中 padding 的 影 响 , 对 其 做 mask masked_softmax ( q u e r y ∗ k e y T ) = 1 × num_heads × 1 × sr_len_k ⇓ masked_softmax ( q u e r y ∗ k e y T ) ∗ v a l u e = 1 × num_heads × 1 × hidden_size num_heads \begin{aligned} \boldsymbol {query} * \boldsymbol {key}^T = \text{1} \times \text{num\_heads}& \times \text{1} \times \text{sr\_len\_k}\\ \stackrel{消除 \boldsymbol {key} 中 \text{padding} 的影响,对其做 \text{mask}}{\Downarrow}\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) = \text{1} \times \text{num\_heads}& \times \text{1} \times \text{sr\_len\_k}\\ \Downarrow\\ \text{masked\_softmax}(\boldsymbol {query} * \boldsymbol {key}^T) * \boldsymbol {value} = \text{1} \times \text{num\_heads}& \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} } \end{aligned} querykeyT=1×num_headskeypaddingmaskmasked_softmax(querykeyT)=1×num_headsmasked_softmax(querykeyT)value=1×num_heads×1×sr_len_k×1×sr_len_k×1×num_headshidden_size

假设 num_heads = 1,sr_len_k = 6,因为 mask 只针对key,因此这里只需要关注 source sentence 中的padding, 假设 source sentence 的最后两个位置是padding,那么Masked Multihead Attention 中的 mask 为:
( 1 1 1 1 0 0 ) \begin{pmatrix} 1 & 1 & 1 & 1 & 0 & 0 \\ \end{pmatrix} (111100)

  1. Output: 需要将上面输出的 num_heads 个头的结果堆叠之后,再做一个线性变换:
    1 × num_heads × 1 × hidden_size num_heads ⇓ reshape 1 × 1 × hidden_size ⇓ 线性变换 1 × 1 × hidden_size \begin{aligned} \text{1} \times \text{num\_heads}& \times \text{1} \times \frac{\text{hidden\_size}}{\text{num\_heads} }\\ &\stackrel{\text{reshape}}{\Downarrow}\\ \text{1} \times \text{1} &\times \text{hidden\_size}\\ &\stackrel{\text{线性变换}}{\Downarrow}\\ \text{1} \times \text{1} &\times \text{hidden\_size} \end{aligned} 1×num_heads1×11×1×1×num_headshidden_sizereshape×hidden_size线性变换×hidden_size

由于最后的 Feed Forward 层不改变矩阵大小,至此可以总结一下预测阶段的 Decoder layer,输入是上一 step 输出的 token,大小为 1 × \times × 1 × \times × hidden_size,经过两种MultiHead + Feed Forward 后,大小依然为 1 × \times × 1 × \times × hidden_size,再经过 Linear+Softmax,其输出就是预测的当前 step 的token,而这个 token 又会作为下一个 step 的输入 query。直到达到最大长度,或者输出的 token 是 “<eos>”

  • 16
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值