题目链接
http://118.190.20.162/view.page?gpid=T169
运行结果:
思想
- 根据数据范围推断出,最终结果矩阵中的元素数字可能超过int,需要用long long.当然,用python的话不用考虑这个问题。
- 考虑改变矩阵运算的次序,从而减低时间复杂度。
如果不用long long,会导致结果错误。最好全部矩阵的存储都用long long.
如果不调整矩阵运算次序,会超时。
这里简单说明一下为啥交换次序能优化:
Q:n*d
K:d*n
V:n*d
从前向后算的话,复杂度是:
如果先算Temp = K*V,再算Q*Temp, 复杂度是
满分代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// 通过矩阵运算交换律就可以做很大的优化
void cross_multiply(vector<vector<ll>>& left, vector<vector<ll>>& right,
vector<vector<ll>>& ans, int a, int b, int c){
// left: a*b right : b*c ans: a*c
ll temp = 0;
for(int i=0;i<a;++i){
for(int j=0;j<c;++j){
// get ans[i][j] left的第i行 right的第j列 对应相乘再相加
temp = 0;
for(int k=0;k<b;++k){
temp += left[i][k]*right[k][j];
}
ans[i][j] = temp;
}
}
}
void dot_multiply(vector<ll>&w, vector<vector<ll>>& ans, int n, int d){
// w:n*1 ans:n*d
for(int i=0;i<n;++i){
for(int j=0;j<d;++j){
ans[i][j]*=w[i];
}
}
}
int main(){
// read data
// cout<<sizeof(int)<<" "<<sizeof(long long)<<endl;
int n, d;
cin >> n >> d;
// q, k, v : n*d w : n * 1
vector<vector<ll>> q (n, vector<ll>(d, 0));
vector<vector<ll>> kt (d, vector<ll>(n, 0));//kt for transpose of k
vector<vector<ll>> v (n, vector<ll>(d, 0));
vector<ll> w(n, 0);
for(int i=0;i<n;++i){
for(int j=0;j<d;++j)
cin >> q[i][j];
}
// kt : d * n
for(int i=0;i<n;++i){
for(int j=0;j<d;++j)
cin >> kt[j][i];
}
for(int i=0;i<n;++i){
for(int j=0;j<d;++j)
cin >> v[i][j];
}
// read w, n*1
for(int i=0;i<n;++i){
cin >> w[i];
}
// 不调整顺序,只过70
// vector<vector<ll>> qkt(n, vector<ll>(n, 0));
// cross_multiply(q, kt, qkt, n, d, n);
// dot_multiply(w, qkt, n, n);
// vector<vector<ll>> ans(n, vector<ll>(d, 0));
// cross_multiply(qkt, v, ans, n, n, d);
// 调整计算顺序后的,过100
vector<vector<ll>> ktv(d, vector<ll>(d,0));
cross_multiply(kt, v, ktv, d, n, d);
vector<vector<ll>> ans(n, vector<ll>(d, 0));
cross_multiply(q, ktv, ans, n, d, d);
dot_multiply(w, ans, n, d);
for(int i=0;i<n;++i){
for(int j=0;j<d;++j) cout<<ans[i][j]<<" ";
cout<<endl;
}
return 0;
}
python:
n, d = map(int, input().split())
q, k, v, w = [], [], [], []
for i in range(n):
row = list(map(int, input().split()))
q.append(row)
for i in range(n):
row = list(map(int, input().split()))
k.append(row)
for i in range(n):
row = list(map(int, input().split()))
v.append(row)
w = list(map(int, input().split()))
def cross_multiply(left, right, a, b, c):
# left:a*b right:b*c return ans
ans = [[0]*c for i in range(a)]
for i in range(a):
for j in range(c):
# get ans[i][j]
temp = 0
for k in range(b):
temp += left[i][k]*right[k][j]
ans[i][j]=temp
return ans
# transpose k
kt = []
for i in range(d):
row = []
for j in range(n):
row.append(k[j][i])
kt.append(row)
ktv = cross_multiply(kt, v, d, n, d)
qktv = cross_multiply(q, ktv, n, d, d)
# 对 qktv 和 w 进行点运算
for i in range(n):
for j in range(d):
qktv[i][j] *= w[i]
for i in range(n):
for j in range(d):
print(qktv[i][j], end=' ')
print()
总结
在这个题目的特殊条件(n<=10**4且d<=20)下,由于n的范围远大于d,所以可以用交换计算顺序的技巧来优化。当然,如果n<=10**4 and d <= 10**4, 那么这个优化就没有意义了。
在这里朴实无华的想法却是最强大的。