题意:A为N * K的矩阵,B为K * N的矩阵,执行以下步骤。($N \le 1000$, $K \le 6$)
Step 1: Calculate a new N * N matrix C = A*B.
Step 2: Calculate M = C^(N * N).
Step 3: For each element x in M, calculate x % 6. All the remainders form a new matrix M’.
Step 4: Calculate the sum of all the elements in M’.
题解:矩阵快速幂
如果我们直接计算 N * N 的矩阵 AB,1000*1000 的矩阵放在栈上会直接爆栈;而且对 1000*1000 的矩阵做快速幂(每次矩阵乘法就是 $O(N^3)$)也会超时。
我们可以发现,$(AB)^{N^2} = A\,(BA)^{N^2-1}\,B$(由矩阵乘法结合律,对任意 $m \ge 1$ 都有 $(AB)^m = A\,(BA)^{m-1}\,B$),这样就转换成了最多 6*6 的矩阵 BA 的快速幂。
#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<algorithm>
#include<queue>
#include<stack>
#include<cmath>
#include<vector>
#include<fstream>
#include<set>
#include<map>
#include<sstream>
#include<iomanip>
#define ll long long
using namespace std;
int n, k;
//矩阵快速幂
const int mod = 6;
struct matrix {
ll mat[6][6];
matrix() { memset(mat, 0, sizeof(mat)); }
matrix operator * (const matrix& b)const {
matrix ans;
for (int i = 0; i < k; i++) {
for (int j = 0; j < k; j++) {
ans.mat[i][j] = 0;
for (int kk = 0; kk < k; kk++) {
ans.mat[i][j] = (ans.mat[i][j] + mat[i][kk] * b.mat[kk][j] % mod + mod) % mod;
}
}
}
return ans;
}
};
// Computes a^b by binary exponentiation on the leading k x k block.
// b is expected to be non-negative; b == 0 yields the k x k identity.
// Products go through matrix::operator*, so for b >= 1 entries are in [0, mod).
matrix q_pow(matrix a, ll b) {
    // The constructor already zero-fills, so only the diagonal needs setting.
    matrix result;
    for (int d = 0; d < k; ++d) result.mat[d][d] = 1;
    // Walk the bits of b from least significant upward, squaring the base
    // each step and folding it in whenever the current bit is set.
    for (; b != 0; b >>= 1) {
        if (b & 1) result = result * a;
        a = a * a;
    }
    return result;
}
// Input / intermediate buffers, sized for N <= 1000, K <= 6:
//   a = A (n x k), b = B (k x n), c = A * (BA)^(n*n-1) (n x k).
// The final N x N matrix is never stored; its entries are summed on the fly.
int a[1000][6], b[6][1000], c[1000][6];

// Reads test cases "n k" followed by A (n rows of k values) and B (k rows of
// n values) until EOF, a malformed header, or n == 0. For each case, prints the
// sum over all entries x of (AB)^(n*n) of x % 6.
int main() {
    // Require both header values; a partial read must not reuse a stale k.
    while (scanf("%d%d", &n, &k) == 2 && n) {
        for (int i = 0; i < n; i++)
            for (int j = 0; j < k; j++)
                scanf("%d", &a[i][j]);
        for (int i = 0; i < k; i++)
            for (int j = 0; j < n; j++)
                scanf("%d", &b[i][j]);

        // (AB)^(n*n) = A * (BA)^(n*n-1) * B, and BA is only k x k, so the
        // exponentiation happens on the small matrix.
        matrix C;
        for (int i = 0; i < k; i++)
            for (int j = 0; j < k; j++)
                for (int p = 0; p < n; p++)
                    C.mat[i][j] += b[i][p] * a[p][j];
        C = q_pow(C, n * n - 1);

        // c = A * (BA)^(n*n-1), an n x k matrix. Each entry is reset here, so
        // leftovers from a previous test case never leak in.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < k; j++) {
                c[i][j] = 0;
                for (int p = 0; p < k; p++)
                    c[i][j] += a[i][p] * C.mat[p][j];
            }
        }

        // M = c * B. Reduce each entry mod 6 and accumulate immediately,
        // instead of materializing (and clearing) a 1000 x 1000 buffer.
        int x = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int entry = 0;
                for (int p = 0; p < k; p++)
                    entry += c[i][p] * b[p][j];
                x += entry % mod;
            }
        }
        printf("%d\n", x);
    }
    return 0;
}