【CSP试题回顾】202305-2-矩阵运算

CSP-202305-2-矩阵运算

关键点总结:改变矩阵计算顺序优化时间复杂度

通过先计算 K T × V K ^ T \times V KT×V 而不是先计算 Q × K T Q \times K ^ T Q×KT,有效地减少了计算时间,特别是在处理长序列时。这种优化通常在数据维度一不等时有显著效果,特别是当序列长度显著大于向量维度时。

1.原始的计算顺序

在Transformer的自注意力机制中,给定矩阵 Q Q Q(查询), K K K(键)和 V V V(值),计算首先涉及到以下步骤:

  1. 计算 Q × K T Q \times K ^ T Q×KT(查询和键的点积),得到注意力得分矩阵。
  2. 将注意力得分矩阵乘以 V V V(值矩阵),得到加权的值,这是最终的输出。

原始计算的时间复杂度主要由 Q × K T Q \times K ^ T Q×KT 的计算决定,这个操作的时间复杂度为 O ( n 2 ⋅ d ) O(n ^ 2 \cdot d) O(n2d),其中 n n n 是序列长度(例如,句子中的单词数量或向量数量), d d d 是向量的维度。当 n n n 很大时,这个操作非常耗时。

2.代码中的计算顺序

代码中采取了不同的计算顺序:

  1. 首先,它通过与 W W W 相乘来调整 Q Q Q 中的每个元素(这对应于自注意力机制中的缩放操作,但在这个特定的实现中, W W W 似乎用于不同的目的,比如加权或转换,这并不是标准的自注意力机制的一部分)。

  2. 然后,它计算 K T × V K ^ T \times V KT×V,这个操作的时间复杂度为 O ( n ⋅ d 2 ) O(n \cdot d ^ 2) O(nd2),因为它是在矩阵 K T K ^ T KT(维度 d × n d \times n d×n)和矩阵 V V V(维度 n × d n \times d n×d)之间进行的。

  3. 最后,它计算调整后的 Q Q Q K T × V K ^ T \times V KT×V 的结果,时间复杂度为 O ( n ⋅ d 2 ) O(n \cdot d ^ 2) O(nd2)

3.时间复杂度比较

n > d n > d n>d(即,序列长度大于向量维度)时,代码中的计算顺序比原始计算顺序更有效率。原始方法的复杂度主要是由序列长度的平方决定的,而代码中的方法将这个平方项降低到了 n n n d d d 的乘积,这在大多数实际情况下会减少计算量,尤其是在处理长序列时。

解题思路

搞清楚上面的点后,本质上就是简单的矩阵乘法,留意本题关于矩阵点乘计算规则的定义即可。

完整代码

#include <iostream>
#include <vector>
#include <string>
using namespace std;

int main() {  

    ios_base::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    long long n, d;
    cin >> n >> d;

    vector<vector<long long>>Q(n, vector<long long>(d));
    vector<vector<long long>>K_T(d, vector<long long>(n));
    vector<vector<long long>>V(n, vector<long long>(d));
    vector<long long>W(n);

    // 输入Q
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < d; j++)
        {
            cin >> Q[i][j];
        }
    }
    // 输入K_T
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < d; j++)
        {
            cin >> K_T[j][i];
        }
    }
    // 输入V
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < d; j++)
        {
            cin >> V[i][j];
        }
    }
    // 输入W
    for (int i = 0; i < n; i++)
    {
        cin >> W[i];
    }

    // 计算 W * Q
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < d; j++)
        {
            Q[i][j] *= W[i];
        }
    }
     
    // 计算 K_T * V
    vector<vector<long long>>T1(d, vector<long long>(d));
    for (int i = 0; i < d; i++)
    {
        for (int j = 0; j < d; j++)
        {
            for (int k = 0; k < n; k++)
            {
                T1[i][j] += K_T[i][k] * V[k][j];
            }
        }
    }

    // 计算 Q * T1
    vector<vector<long long>>T2(n, vector<long long>(d));
    for (int i = 0; i < n; i++)
    {
        for (int j = 0; j < d; j++)
        {
            for (int k = 0; k < d; k++)
            {
                T2[i][j] += Q[i][k] * T1[k][j];
            }
        }
    }

    for (const auto& it : T2) {
        for (const auto& jt : it) {
            cout << jt << " ";
        }
        cout << endl;
    }

    return 0;
}

请添加图片描述

  • 27
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值