题目:
样例输入:
3 2 1 2 3 4 5 6 10 10 -20 -20 30 30 6 5 4 3 2 1 4 0 -5
样例输出:
480 240 0 0 -2200 -1100
子任务:
70 %的测试数据满足:n≤100 且 d≤10;输入矩阵、向量中的元素均为整数,且绝对值均不超过 30。
全部的测试数据满足:n≤1e4 且 d≤20;输入矩阵、向量中的元素均为整数,且绝对值均不超过 1000。
提示:
请谨慎评估矩阵乘法运算后的数值范围,并使用适当数据类型存储矩阵中的整数。
分析:
如果按照题目给出的顺序计算:
首先一个n行d列的矩阵与d行n列的矩阵相乘得到一个n行n列的矩阵,计算n*n*d次,n和d均取最大值,即2e9次。一般1e9就有可能超时了。然后该矩阵的每一行与对应的Wi相乘,还是一个n行n列的矩阵,用时n*n次。最后与n*n矩阵与n*d的矩阵相乘,用时n*n*d次。所以按照这个顺序肯定是过不了的。
优化:
d*n 与 n *d 相乘 得到 d*d 的矩阵,用时d*d*n。n*d 与 d*d的矩阵相乘,用时 n*d*d。最后一步用时n*d。按照这样的顺序就能过了。
代码:
#include<bits/stdc++.h>
#include<unordered_map>
using namespace std;
#define int long long
const int N = 1e4 + 5, D = 25;
int q[N][D], k[N][D], V[N][D], w[N], ans[N][D], tsb[D][D];
int n, d;
signed main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n >> d;
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= d; j++) {
cin >> q[i][j];
}
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= d; j++) {
cin >> k[i][j];
}
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= d; j++) {
cin >> V[i][j];
}
}
for (int i = 1; i <= n; i++) {
cin >> w[i];
}
for (int i = d; i >= 1; i--) {
for (int j = 1; j <= d; j++) {
int num = 0;
int u = 1, v = 1;
while (u <= n && v <= n) {
num += k[u][i] * V[v][j];
u++, v++;
}
//cout << num << " ";
tsb[i][j] = num;
}
//cout << "\n";
}
for (int i = 1; i <= n; i++) {
for (int j = 1; j <= d; j++) {
int u = 1, v = 1;
while (u <= d && v <= d) {
ans[i][j] += q[i][u] * tsb[v][j];
u++, v++;
}
ans[i][j] *= w[i];
cout << ans[i][j] << " ";
}
cout << "\n";
}
return 0;
}