优化方法:先算 K^T×V,再算 Q×(K^T×V)。(若先算 Q×K^T:(n×d)×(d×n) 得到 n×n 的矩阵,总复杂度 O(n²d);改为先算 K^T×V:(d×n)×(n×d) 得到 d×d 的矩阵,总复杂度 O(nd²)。此题 d 远比 n 小,所以该优化可行。)
#include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for every matrix entry and running sum.
using ll = long long;

// Array extents: the row count (indexed 1..n) and the inner
// dimension (indexed 1..d). Both include one spare slot because
// indices start at 1.
constexpr ll maxn = 10005;
constexpr ll maxd = 25;

// Matrix sizes read from input: n rows, d columns.
ll n, d;

// Inputs: Q and V are stored n x d; K is stored already transposed
// (d x n) so that K^T x V can walk it row-wise.
ll q[maxn][maxd];
ll k[maxd][maxn];
ll v[maxn][maxd];
// Per-row weight vector W.
ll w[maxn];

// Intermediates: kv = K^T x V (d x d), ww = Q x kv (n x d), which is
// later scaled row-wise by w.
ll kv[maxd][maxd];
ll ww[maxn][maxd];
namespace {

using Matrix = std::vector<std::vector<long long>>;

// Reads a rows x cols integer matrix from `in`, one row at a time.
Matrix read_matrix(std::istream& in, std::size_t rows, std::size_t cols)
{
    Matrix m(rows, std::vector<long long>(cols, 0));
    for (auto& row : m)
    {
        for (auto& cell : row)
        {
            in >> cell;
        }
    }
    return m;
}

// Returns K^T x V, a cols x cols matrix, where `keys` and `values` are
// both rows x cols. Accumulating one input row at a time reads both
// matrices in storage order; integer addition is exact, so the order of
// summation does not change the result.
Matrix transpose_times(const Matrix& keys, const Matrix& values, std::size_t cols)
{
    Matrix product(cols, std::vector<long long>(cols, 0));
    for (std::size_t r = 0; r < keys.size(); ++r)
    {
        for (std::size_t a = 0; a < cols; ++a)
        {
            const long long key_ra = keys[r][a];
            for (std::size_t b = 0; b < cols; ++b)
            {
                product[a][b] += key_ra * values[r][b];
            }
        }
    }
    return product;
}

}  // namespace

// Reads n and d, then Q, K, V (each n x d, row by row) and the n-entry
// weight vector W. Prints (W . (Q x K^T)) x V, where "W ." scales row i
// by W[i]. Each value is followed by a space, and each row ends with a
// newline.
//
// Forming Q x K^T first would build an n x n matrix, costing O(n^2 * d).
// Computing K^T x V (d x d) first and then Q x (K^T x V) gives the same
// result in O(n * d^2), which pays off when d is much smaller than n.
// Matrices are sized from the input, so no fixed upper bound on n or d
// is assumed.
// NOTE(review): assumes entries are small enough that the long long
// sums cannot overflow — confirm against the problem's input bounds.
int main()
{
    // Bulk numeric input/output; this program does not use C stdio.
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    long long n_rows = 0;
    long long dim = 0;
    if (!(std::cin >> n_rows >> dim) || n_rows < 0 || dim < 0)
    {
        return 0;  // malformed header: nothing to compute
    }
    const auto rows = static_cast<std::size_t>(n_rows);
    const auto cols = static_cast<std::size_t>(dim);

    const Matrix query = read_matrix(std::cin, rows, cols);
    const Matrix key = read_matrix(std::cin, rows, cols);
    const Matrix value = read_matrix(std::cin, rows, cols);
    std::vector<long long> weight(rows, 0);
    for (auto& x : weight)
    {
        std::cin >> x;
    }

    const Matrix kt_v = transpose_times(key, value, cols);

    // Row i of the answer is weight[i] * (row i of Q) x (K^T x V).
    for (std::size_t i = 0; i < rows; ++i)
    {
        for (std::size_t j = 0; j < cols; ++j)
        {
            long long sum = 0;
            for (std::size_t t = 0; t < cols; ++t)
            {
                sum += query[i][t] * kt_v[t][j];
            }
            std::cout << sum * weight[i] << ' ';
        }
        std::cout << '\n';
    }
    return 0;
}