多头注意力用单元矩阵实现以及原因

一、 PyTorch 中实现这个过程。

**本质是在横向加Q1,Q2**

1. 创建单一的权重矩阵

假设我们有以下参数:

  • input_dim = 64:输入的维度。
  • num_heads = 8:注意力头的数量。
  • head_dim = 16:每个头的维度。

我们可以先创建一个大的权重矩阵 W_query,然后进行拆分。

import torch
import torch.nn as nn

# 假设输入维度是 64,每个头的维度是 16,总共 8 个头
input_dim = 64
num_heads = 8
head_dim = 16

# 创建一个大的 W_query 矩阵,形状为 (input_dim, num_heads * head_dim)
W_query = nn.Linear(input_dim, num_heads * head_dim)

# 假设我们有一个输入 X,形状为 (batch_size, seq_len, input_dim)
batch_size = 32
seq_len = 10
X = torch.rand(batch_size, seq_len, input_dim)

# 通过线性层得到 Q 矩阵,形状为 (batch_size, seq_len, num_heads * head_dim)
Q = W_query(X)  # Q 的形状为 (batch_size, seq_len, num_heads * head_dim)

# 将 Q 拆分成多个头,每个头的形状为 (batch_size, seq_len, head_dim)
# 拆分后的形状为 (batch_size, seq_len, num_heads, head_dim)
Q = Q.view(batch_size, seq_len, num_heads, head_dim)

# 如果需要将 Q 传递给每个头独立处理,可以对维度进行调整,形状为 (batch_size, num_heads, seq_len, head_dim)
Q = Q.permute(0, 2, 1, 3)

2. 拆分后的矩阵如何用于多头注意力机制

拆分后的 Q 矩阵可以直接用于每个注意力头的计算,其他的 KV 矩阵也可以类似地处理。

# 例如,假设我们有 K 和 V 矩阵
W_key = nn.Linear(input_dim, num_heads * head_dim)
W_value = nn.Linear(input_dim, num_heads * head_dim)

K = W_key(X).view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
V = W_value(X).view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)

# 计算每个头的注意力得分
attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)  # 形状为 (batch_size, num_heads, seq_len, seq_len)

# 对注意力得分进行 softmax 操作
attention_weights = torch.softmax(attention_scores, dim=-1)

# 计算每个头的注意力输出
attention_output = torch.matmul(attention_weights, V)  # 形状为 (batch_size, num_heads, seq_len, head_dim)

# 最后将所有头的输出拼接在一起,形成最终的输出
attention_output = attention_output.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, num_heads * head_dim)

3、小结

  1. 创建单一权重矩阵:我们创建一个大矩阵,其输出维度是 num_heads * head_dim
  2. 拆分矩阵:通过 .view.permute 方法,将这个大矩阵的输出拆分为多个小矩阵,每个小矩阵对应一个注意力头。
  3. 计算注意力:每个头独立计算注意力得分,然后将结果合并。

你提出的问题实际上涉及到参数共享的误解。让我更详细地解释这个问题,澄清在创建独立矩阵和共享矩阵时参数量的实际情况。

二、 参数量

无论是创建独立矩阵还是共享矩阵,参数量都是相同的。每种方法的参数量计算结果都是 8192。

假设:

  • input_dim = 64:输入的维度。
  • num_heads = 8:注意力头的数量。
  • head_dim = 16:每个头的维度。
1. 独立矩阵的情况

如果为每个注意力头创建独立的权重矩阵,则每个头的权重矩阵的参数量为 input_dim * head_dim。对于 num_heads 个头,整个参数量为:

总参数量 = num_heads × input_dim × head_dim \text{总参数量} = \text{num\_heads} \times \text{input\_dim} \times \text{head\_dim} 总参数量=num_heads×input_dim×head_dim

根据假设,参数量计算如下:

总参数量 = 8 × 64 × 16 = 8192 \text{总参数量} = 8 \times 64 \times 16 = 8192 总参数量=8×64×16=8192

2. 共享矩阵的情况

如果我们使用一个大的共享矩阵,并将其拆分为多个头,则共享矩阵的形状为 (input_dim, num_heads * head_dim)。因此,整个矩阵的参数量为:

总参数量 = input_dim × ( num_heads × head_dim ) \text{总参数量} = \text{input\_dim} \times (\text{num\_heads} \times \text{head\_dim}) 总参数量=input_dim×(num_heads×head_dim)

根据假设,参数量计算如下:

总参数量 = 64 × ( 8 × 16 ) = 64 × 128 = 8192 \text{总参数量} = 64 \times (8 \times 16) = 64 \times 128 = 8192 总参数量=64×(8×16)=64×128=8192

三、 为什么使用共享矩阵

  1. 代码简洁性:使用共享矩阵可以简化代码,实现统一的矩阵运算,减少了手动管理多个独立矩阵的复杂性。

  2. 计算效率:共享矩阵可以利用并行计算的优势,使得在进行矩阵运算时更加高效,因为只需一次大的矩阵乘法操作,而不需要为每个头分别计算。

  3. 实现上的一致性:共享矩阵的实现方式与多头注意力机制的逻辑一致(即所有头都是从同一个大的线性变换中分离出来的),更符合框架的设计和优化原则。

为了帮助你更直观地理解共享矩阵与独立矩阵的关系,我会举一个具体的数值例子,说明共享矩阵在进行一次性矩阵乘法后分解成多个独立矩阵的结果与独立矩阵逐一计算的结果是相同的。

四 、数值示例直观说明是结果是一致的

设定参数

假设我们有以下参数:

  • input_dim = 4:输入的维度。
  • num_heads = 2:注意力头的数量。
  • head_dim = 2:每个头的维度。
  • 输入矩阵 X 的形状为 (batch_size=1, seq_len=3, input_dim=4)
输入矩阵 X

X = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] X = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} X= 159261037114812

独立矩阵的权重

我们为每个头单独创建 W_query_1W_query_2 矩阵:
W _ q u e r y _ 1 = [ 1 0 0 1 1 0 0 1 ] , W _ q u e r y _ 2 = [ 2 1 1 2 2 1 1 2 ] W\_query\_1 = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}, \quad W\_query\_2 = \begin{bmatrix} 2 & 1 \\ 1 & 2 \\ 2 & 1 \\ 1 & 2 \end{bmatrix} W_query_1= 10100101 ,W_query_2= 21211212

1. 独立矩阵逐一计算

我们对每个头的 Q 进行计算:

  • 对第一个头计算 Q_1
    Q 1 = X × W _ q u e r y _ 1 = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 1 0 0 1 1 0 0 1 ] = [ 4 6 12 14 20 22 ] Q_1 = X \times W\_query\_1 = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix} = \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} Q1=X×W_query_1= 159261037114812 × 10100101 = 4122061422

  • 对第二个头计算 Q_2
    Q 2 = X × W _ q u e r y _ 2 = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 2 1 1 2 2 1 1 2 ] = [ 10 16 30 40 50 64 ] Q_2 = X \times W\_query\_2 = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 2 & 1 \\ 1 & 2 \\ 2 & 1 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} Q2=X×W_query_2= 159261037114812 × 21211212 = 103050164064

2. 共享矩阵一次性计算

我们将上述两个矩阵组合为一个大的共享矩阵 W_query,其形状为 (input_dim=4, num_heads * head_dim=4)
W _ q u e r y = [ 1 0 2 1 0 1 1 2 1 0 2 1 0 1 1 2 ] W\_query = \begin{bmatrix} 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \\ 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \end{bmatrix} W_query= 1010010121211212

然后一次性计算出所有头的 Q 矩阵:
Q = X × W _ q u e r y = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 1 0 2 1 0 1 1 2 1 0 2 1 0 1 1 2 ] = [ 4 6 10 16 12 14 30 40 20 22 50 64 ] Q = X \times W\_query = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \\ 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \end{bmatrix} = \begin{bmatrix} 4 & 6 & 10 & 16 \\ 12 & 14 & 30 & 40 \\ 20 & 22 & 50 & 64 \end{bmatrix} Q=X×W_query= 159261037114812 × 1010010121211212 = 4122061422103050164064

3. 将共享矩阵结果拆分为独立矩阵

最后,我们将 Q 矩阵按列拆分成两个子矩阵,每个子矩阵对应一个注意力头:

  • 第一个头的结果 Q_1
    Q 1 = [ 4 6 12 14 20 22 ] Q_1 = \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} Q1= 4122061422

  • 第二个头的结果 Q_2
    Q 2 = [ 10 16 30 40 50 64 ] Q_2 = \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} Q2= 103050164064

4. 对比结果

从上述步骤可以看到,使用共享矩阵一次性计算并拆分后的结果与使用独立矩阵逐一计算的结果完全相同。具体来说:

  • Q_1 的结果在独立和共享情况下都为:
    [ 4 6 12 14 20 22 ] \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} 4122061422

  • Q_2 的结果在独立和共享情况下都为:
    [ 10 16 30 40 50 64 ] \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} 103050164064

计算 QK 的点积,即 Q * K^T,以展示共享矩阵和独立矩阵在这一步的等价性。

假设

继续沿用之前的参数和矩阵:

  • 输入矩阵 X 的形状为 (batch_size=1, seq_len=3, input_dim=4)
  • W_queryW_key 的设置方式与之前相同,分为独立矩阵和共享矩阵两种情况。

1. 独立矩阵的计算

假设我们有独立的 W_key_1W_key_2

W _ k e y _ 1 = W _ q u e r y _ 1 = [ 1 0 0 1 1 0 0 1 ] , W _ k e y _ 2 = W _ q u e r y _ 2 = [ 2 1 1 2 2 1 1 2 ] W\_key\_1 = W\_query\_1 = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}, \quad W\_key\_2 = W\_query\_2 = \begin{bmatrix} 2 & 1 \\ 1 & 2 \\ 2 & 1 \\ 1 & 2 \end{bmatrix} W_key_1=W_query_1= 10100101 ,W_key_2=W_query_2= 21211212

首先计算独立的 K_1K_2

  • 对第一个头计算 K_1
    K 1 = X × W _ k e y _ 1 = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 1 0 0 1 1 0 0 1 ] = [ 4 6 12 14 20 22 ] K_1 = X \times W\_key\_1 = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix} = \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} K1=X×W_key_1= 159261037114812 × 10100101 = 4122061422

  • 对第二个头计算 K_2
    K 2 = X × W _ k e y _ 2 = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 2 1 1 2 2 1 1 2 ] = [ 10 16 30 40 50 64 ] K_2 = X \times W\_key\_2 = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 2 & 1 \\ 1 & 2 \\ 2 & 1 \\ 1 & 2 \end{bmatrix} = \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} K2=X×W_key_2= 159261037114812 × 21211212 = 103050164064

然后计算 Q_1 * K_1^TQ_2 * K_2^T

  • Q_1 * K_1^T
    Q 1 × K 1 T = [ 4 6 12 14 20 22 ] × [ 4 12 20 6 14 22 ] = [ 52 120 188 120 292 464 188 464 740 ] Q_1 \times K_1^T = \begin{bmatrix} 4 & 6 \\ 12 & 14 \\ 20 & 22 \end{bmatrix} \times \begin{bmatrix} 4 & 12 & 20 \\ 6 & 14 & 22 \end{bmatrix} = \begin{bmatrix} 52 & 120 & 188 \\ 120 & 292 & 464 \\ 188 & 464 & 740 \end{bmatrix} Q1×K1T= 4122061422 ×[4612142022]= 52120188120292464188464740

  • Q_2 * K_2^T
    Q 2 × K 2 T = [ 10 16 30 40 50 64 ] × [ 10 30 50 16 40 64 ] = [ 356 820 1284 820 1960 3100 1284 3100 4916 ] Q_2 \times K_2^T = \begin{bmatrix} 10 & 16 \\ 30 & 40 \\ 50 & 64 \end{bmatrix} \times \begin{bmatrix} 10 & 30 & 50 \\ 16 & 40 & 64 \end{bmatrix} = \begin{bmatrix} 356 & 820 & 1284 \\ 820 & 1960 & 3100 \\ 1284 & 3100 & 4916 \end{bmatrix} Q2×K2T= 103050164064 ×[101630405064]= 356820128482019603100128431004916

2. 共享矩阵的计算

使用共享矩阵时,先计算出 QK

Q = X × W _ q u e r y = [ 1 2 3 4 5 6 7 8 9 10 11 12 ] × [ 1 0 2 1 0 1 1 2 1 0 2 1 0 1 1 2 ] = [ 4 6 10 16 12 14 30 40 20 22 50 64 ] Q = X \times W\_query = \begin{bmatrix} 1 & 2 & 3 & 4 \\ 5 & 6 & 7 & 8 \\ 9 & 10 & 11 & 12 \end{bmatrix} \times \begin{bmatrix} 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \\ 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \end{bmatrix} = \begin{bmatrix} 4 & 6 & 10 & 16 \\ 12 & 14 & 30 & 40 \\ 20 & 22 & 50 & 64 \end{bmatrix} Q=X×W_query= 159261037114812 × 1010010121211212 = 4122061422103050164064

K 的计算与 Q 相同,因为 W_key = W_query

接下来计算 Q * K^T

Q × K T = [ 4 6 10 16 12 14 30 40 20 22 50 64 ] × [ 4 12 20 6 14 22 10 30 50 16 40 64 ] = [ 52 120 188 356 820 1284 120 292 464 820 1960 3100 188 464 740 1284 3100 4916 ] Q \times K^T = \begin{bmatrix} 4 & 6 & 10 & 16 \\ 12 & 14 & 30 & 40 \\ 20 & 22 & 50 & 64 \end{bmatrix} \times \begin{bmatrix} 4 & 12 & 20 \\ 6 & 14 & 22 \\ 10 & 30 & 50 \\ 16 & 40 & 64 \end{bmatrix} = \begin{bmatrix} 52 & 120 & 188 & 356 & 820 & 1284 \\ 120 & 292 & 464 & 820 & 1960 & 3100 \\ 188 & 464 & 740 & 1284 & 3100 & 4916 \end{bmatrix} Q×KT= 4122061422103050164064 × 4610161214304020225064 = 52120188120292464188464740356820128482019603100128431004916

3. 将共享矩阵结果拆分为独立矩阵结果

我们可以将结果矩阵按头的数量拆分成两个部分:

  • 前两列对应 Q_1 * K_1^T
    [ 52 120 188 120 292 464 188 464 740 ] \begin{bmatrix} 52 & 120 & 188 \\ 120 & 292 & 464 \\ 188 & 464 & 740 \end{bmatrix} 52120188120292464188464740

  • 后四列对应 Q_2 * K_2^T
    [ 356 820 1284 820 1960 3100 1284 3100 4916 ] \begin{bmatrix} 356 & 820 & 1284 \\ 820 & 1960 & 3100 \\ 1284 & 3100 & 4916 \end{bmatrix} 356820128482019603100128431004916

4. 对比结果

从上述结果可以看到,使用共享矩阵一次性计算并拆分后的 Q * K^T 结果与使用独立矩阵逐一计算的结果完全一致:

  • Q_1 * K_1^T 在共享和独立情况下的结果一致。
  • Q_2 * K_2^T 在共享和独立情况下的结果一致。

总结

验证了共享矩阵和独立矩阵在计算 Q * K^T 时的等价性。虽然操作顺序和形式不同,但它们最终得到的结果是完全相同的。

  • 26
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Ai君臣

学会的就要教给人

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值