题目背景
Softmax(Q×KTd)×V 是 Transformer 中注意力模块的核心算式,其中 Q、K 和 V 均是 n 行 d 列的矩阵,KT 表示矩阵 K 的转置,× 表示矩阵乘法。
问题描述
为了方便计算,顿顿同学将 Softmax 简化为了点乘一个大小为 n 的一维向量 W:
(W⋅(Q×KT))×V
点乘即对应位相乘,记 W(i) 为向量 W 的第 i 个元素,即将 (Q×KT) 第 i 行中的每个元素都与 W(i) 相乘。
现给出矩阵 Q、K 和 V 和向量 W,试计算顿顿按简化的算式计算的结果。
输入格式
从标准输入读入数据。
输入的第一行包含空格分隔的两个正整数 n 和 d,表示矩阵的大小。
接下来依次输入矩阵 Q、K 和 V。每个矩阵输入 n 行,每行包含空格分隔的 d 个整数,其中第 i 行的第 j 个数对应矩阵的第 i 行、第 j 列。
最后一行输入 n 个整数,表示向量 W。
输出格式
输出到标准输出中。
输出共 n 行,每行包含空格分隔的 d 个整数,表示计算的结果。
样例输入
3 2
1 2
3 4
5 6
10 10
-20 -20
30 30
6 5
4 3
2 1
4 0 -5
样例输出
480 240
0 0
-2200 -1100
子任务
70 的测试数据满足:n≤100 且 d≤10;输入矩阵、向量中的元素均为整数,且绝对值均不超过 30。
全部的测试数据满足:n≤104 且 d≤20;输入矩阵、向量中的元素均为整数,且绝对值均不超过 1000。
提示
请谨慎评估矩阵乘法运算后的数值范围,并使用适当数据类型存储矩阵中的整数。
题目分析
按照题目步骤进行矩阵操作即可。
代码示例:得分70,运行超时
#include<iostream>
#include<vector>
using namespace std;
int n, d;//n行d列
//输入矩阵
void cinMatrix(vector<vector<int> > &x) {
int i, j;
int a;
for (i = 0; i < n; i++)
{
vector<int> m;
for (j = 0; j < d; j++)
{
cin >> a;
m.push_back(a);
}
x.push_back(m);
}
}
//不需要转置矩阵,直接乘
void matrix(vector<vector<int> >& x, vector<vector<int> > y)
{
vector<vector<int> >m;
int a = 0;
int i, j, z;
for (i = 0; i < n; i++)
{
vector<int> k;
for (z = 0; z < n; z++)
{
for (j = 0; j < d; j++)
{
a += x[i][j] * y[z][j];
}
k.push_back(a);
a = 0;
}
m.push_back(k);
}
x = m;
}
//n*n的矩阵和n*d的矩阵
void matrix1(vector<vector<int> >& x, vector<vector<int> > y)
{
vector<vector<int> >m;
int a = 0;
int i, j, z;
for (i = 0; i < n; i++)
{
vector<int> k;
for (z = 0; z < d; z++)
{
for (j = 0; j < n; j++)
{
a += x[i][j] * y[j][z];
}
k.push_back(a);
a = 0;
}
m.push_back(k);
}
x = m;
}
int main() {
cin >> n >> d;
int i, j;
//转置矩阵:将矩阵的行列互换得到的新矩阵称为转置矩阵,转置矩阵的行列式不变。
//输入Q\K\V
vector<vector<int> >Q;
vector<vector<int> >K;
vector<vector<int> >V;
cinMatrix(Q);
cinMatrix(K);
cinMatrix(V);
vector<int> w;
int a = 0;//临时变量
for (i = 0; i < n; i++)
{
cin >> a;
w.push_back(a);
}
//zK(K);
matrix(Q, K);
for (i = 0; i < n; i++)
{
for (j = 0; j < n; j++)
{
Q[i][j] *= w[i];
}
}
matrix1(Q, V);
//输出
for (i = 0; i < n; i++)
{
for (j = 0; j < d; j++)
cout << Q[i][j] << ' ';
cout << endl;
}
return 0;
}
代码优化:得分100,结果正确
1)超时的原因在于先算Q和K转置的乘积。矩阵乘法满足结合律,可以先算K和V的乘积进行优化,点乘运算不影响结果。
具体原因分析:
a.对于K转置矩阵*V(d*n和n*d),得出一个元素要进行的乘法次数为n次,然后把这n次乘法相加;得出一行元素要进行的乘法次数为nd;得出整个矩阵要进行的乘法次数为ndd;
b.对于Q*上面得出的矩阵(n*d和d*d),得出一个元素要进行的乘法次数为d次,然后把这d次乘法相加;得出一行元素要进行的乘法次数为dd;得出整个矩阵要进行的乘法次数为ndd;
c.上述总共进行的乘法次数为(ndd)*(ndd)
再来看没有优化前的情况:
a.对于Q*K转置矩阵(n*d和d*n),得出一个元素要进行的乘法次数为d次,然后把这d次乘法相加;得出一行元素要进行的乘法次数为nd;得出整个矩阵要进行的乘法次数为nnd;
b.对于上面得出的矩阵*V(n*n和n*d),得出一个元素要进行的乘法次数为n次,然后把这n次乘法相加;得出一行元素要进行的乘法次数为nd;得出整个矩阵要进行的乘法次数为nnd;
c.上述总共进行的乘法次数为(nnd)*(nnd)
由于n和d的取值范围已知,没有优化的乘法次数是优化后的(n/d)的平方,最高可达10的5次方倍,啧啧,时间太长了。
2)还需要注意取值的范围,假设每个元素都为1000,n=10000,d=20,那么对于K转置矩阵*V的d*d矩阵的每个元素已经为10的10次方,再继续和前面相乘,显然已经超过int型的表示范围,要选择恰当的数据类型long long。
3)其中比较容易绕的点在于,理清矩阵乘法的行列规则,在循环中注意n和d是哪个矩阵的行和列。
#include<iostream>
#include<vector>
using namespace std;
int n, d;//n行d列
//输入矩阵
void cinMatrix(vector<vector<long long> >& x) {
int i, j;
long long a;
for (i = 0; i < n; i++)
{
vector<long long> m;
for (j = 0; j < d; j++)
{
cin >> a;
m.push_back(a);
}
x.push_back(m);
}
}
//不需要转置矩阵,直接乘n*d和n*d,不转置,应该为列乘列
void matrix(vector<vector<long long> >& x, vector<vector<long long> > y)
{
vector<vector<long long> >m;
long long a = 0;
int i, j, z;
for (i = 0; i < d; i++)
{
vector<long long> k;
for (z = 0; z < d; z++)
{
for (j = 0; j < n; j++)
{
a += x[j][i] * y[j][z];
}
k.push_back(a);
a = 0;
}
m.push_back(k);
}
x = m;
}
//n*d的矩阵和d*d的矩阵
void matrix1(vector<vector<long long> >& x, vector<vector<long long> > y)
{
vector<vector<long long> >m;
long long a = 0;
int i, j, z;
for (i = 0; i < n; i++)
{
vector<long long> k;
for (z = 0; z < d; z++)
{
for (j = 0; j < d; j++)
{
a += x[i][j] * y[j][z];
}
k.push_back(a);
a = 0;
}
m.push_back(k);
}
x = m;
}
int main() {
cin >> n >> d;
int i, j;
//转置矩阵:将矩阵的行列互换得到的新矩阵称为转置矩阵,转置矩阵的行列式不变。
//输入Q\K\V
vector<vector<long long> >Q;
vector<vector<long long> >K;
vector<vector<long long> >V;
cinMatrix(Q);
cinMatrix(K);
cinMatrix(V);
vector<long long> w;
long long a = 0;//临时变量
for (i = 0; i < n; i++)
{
cin >> a;
w.push_back(a);
}
matrix(K, V);
matrix1(Q, K);
for (i = 0; i < n; i++)
{
for (j = 0; j < d; j++)
{
Q[i][j] *= w[i];
}
}
//输出
for (i = 0; i < n; i++)
{
for (j = 0; j < d; j++)
cout << Q[i][j] << ' ';
cout << endl;
}
return 0;
}