Transformer多头注意力的计算量和单头注意力计算量比较

多头注意力机制(Multi-Head Attention)在Transformer中引入了多个并行的注意力头,每个注意力头可以学习到不同的特征表示。尽管这种机制增强了模型的表达能力,但也增加了一些计算量。下面详细比较一下多头注意力和单头注意力的计算量。

单头注意力计算量

假设输入序列的长度为 \( n \),词嵌入维度为 \( d_{\text{model}} \)。

1. **线性变换**:
   将输入矩阵 \( \mathbf{X} \) 转换为查询 \( \mathbf{Q} \)、键 \( \mathbf{K} \)、值 \( \mathbf{V} \) 矩阵,每个变换的计算量为:
   \[
   \mathcal{O}(n \times d_{\text{model}} \times d_{\text{model}})
   \]
   由于有三个这样的变换,总计算量为:
   \[
   3 \times \mathcal{O}(n \times d_{\text{model}}^2)
   \]

2. **注意力得分计算**:
   计算查询和键的点积,得到注意力得分矩阵 \( \mathbf{Q} \mathbf{K}^T \),计算量为:
   \[
   \mathcal{O}(n \times d_{\text{model}} \times n)
   \]

3. **softmax和加权求和**:
   计算注意力权重并对值进行加权求和,计算量为:
   \[
   \mathcal{O}(n^2 \times d_{\text{model}})
   \]

多头注意力计算量

假设有 \( h \) 个注意力头,每个头的维度为 \( d_k = \frac{d_{\text{model}}}{h} \)。

1. **线性变换**:
   将输入矩阵 \( \mathbf{X} \) 转换为每个头的查询 \( \mathbf{Q} \)、键 \( \mathbf{K} \)、值 \( \mathbf{V} \) 矩阵,每个变换的计算量为:
   \[
   \mathcal{O}(n \times d_{\text{model}} \times d_k)
   \]
   总计算量为:
   \[
   3 \times h \times \mathcal{O}(n \times d_{\text{model}} \times d_k) = 3 \times \mathcal{O}(n \times d_{\text{model}}^2)
   \]

2. **注意力得分计算**:
   计算每个头的查询和键的点积,得到注意力得分矩阵 \( \mathbf{Q}_i \mathbf{K}_i^T \),每个头的计算量为:
   \[
   \mathcal{O}(n \times d_k \times n)
   \]
   总计算量为:
   \[
   h \times \mathcal{O}(n \times d_k \times n) = \mathcal{O}(n^2 \times d_{\text{model}})
   \]

3. **softmax和加权求和**:
   计算每个头的注意力权重并对值进行加权求和,计算量为:
   \[
   h \times \mathcal{O}(n^2 \times d_k) = \mathcal{O}(n^2 \times d_{\text{model}})
   \]

4. **拼接和线性变换**:
   将所有头的输出拼接并进行一次线性变换,总计算量为:
   \[
   \mathcal{O}(n \times d_{\text{model}} \times d_{\text{model}})
   \]

总计算量比较

- **单头注意力**的总计算量为:
  \[
  3 \times \mathcal{O}(n \times d_{\text{model}}^2) + \mathcal{O}(n^2 \times d_{\text{model}})
  \]

- **多头注意力**的总计算量为:
  \[
  3 \times \mathcal{O}(n \times d_{\text{model}}^2) + \mathcal{O}(n^2 \times d_{\text{model}}) + \mathcal{O}(n \times d_{\text{model}}^2)
  \]
可以简化为:
  \[
  4 \times \mathcal{O}(n \times d_{\text{model}}^2) + \mathcal{O}(n^2 \times d_{\text{model}})
  \]

从上述计算可以看出,多头注意力机制在计算量上增加了一倍的线性变换计算量,但主要的计算瓶颈仍然在 \(\mathcal{O}(n^2 \times d_{\text{model}})\) 这一项上,即注意力得分的计算和softmax的加权求和。这表明多头注意力虽然计算量更大,但并不会成倍增加复杂度,而是通过增加计算资源来增强模型的表达能力和学习不同特征的能力。
 

### 自注意力机制的时间复杂度计算 自注意力机制的核心在于通过键值查询的方式捕获序列中的依赖关系。假设输入序列为长度 \( n \),每个 token 的维度为 \( d \)。 #### 计算过程 1. **矩阵乘法操作** 在自注意力机制中,首先需要计算 Query (Q), Key (K), Value (V) 矩阵。这些矩阵由输入嵌入经过线性变换得到。对于每一对 Q K 向量之间的点积运算,其时间复杂度为 \( O(d) \)[^1]。由于有 \( n \times n \) 对这样的向量组合,因此总的点积计算复杂度为 \( O(n^2d) \)。 2. **Softmax 归一化** Softmax 函数应用于上述点积的结果上,以获得权重分布。此步骤涉及对每一行执行指数函数操作,总共有 \( n \) 行,每行包含 \( n \) 个元素。所以该部分的时间复杂度大约也是 \( O(n^2) \)[^1]。 3. **加权求** 使用 softmax 输出作为权重来聚合 V 值。这一阶段涉及到将大小为 \( n \times d \) 的矩阵相乘一次,从而带来额外的 \( O(nd) \) 时间开销。 综上所述,整个 self-attention 层的主要瓶颈来自于第一步的大规模点积运算,最终整体时间复杂度可表示为: \[ O(n^2d) \] 这是标准实现下的理论上限,在实际应用中可能会因为硬件优化等因素有所变化。 ```python import numpy as np def calculate_self_attention(Q, K, V): """ A simple implementation to demonstrate the computational cost. Parameters: Q : numpy array of shape (n, d_k) The query matrix. K : numpy array of shape (n, d_k) The key matrix. V : numpy array of shape (n, d_v) The value matrix. Returns: output: numpy array of shape (n, d_v) Result after applying attention mechanism. """ scores = np.dot(Q, K.T) / np.sqrt(K.shape[1]) # Shape: (n, n); Time Complexity: O(n^2 * d_k) distribution = np.softmax(scores, axis=-1) # Time Complexity: O(n^2) output = np.dot(distribution, V) # Time Complexity: O(n * d_v) return output ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值