题意
给出一个 n * k 的矩阵A, 一个 k * n 的矩阵B ( 4 <= n <= 1000 ) (2 <= k<= 6)
进行以下操作 :
1. 计算n * n的矩阵C = A * B
2. 计算矩阵 M = Cn∗nCn∗n
3. 矩阵M中每个元素模6得到M’
4. 计算M’中每个元素的和
思路
比赛期间没想到矩阵M的计算可以这样处理:
M=(A∗B)∗(A∗B)∗(A∗B)∗...∗(A∗B)M=(A∗B)∗(A∗B)∗(A∗B)∗...∗(A∗B)
根据矩阵乘法的结合律
M=A∗(B∗A)∗(B∗A)∗...∗(B∗A)∗BM=A∗(B∗A)∗(B∗A)∗...∗(B∗A)∗B
令C = (B*A)
中间的C矩阵至多是6 * 6矩阵, 就把原来至多1000 * 1000的快速幂简化成为 6 * 6的快速幂
最终的结果即 A * C * B
矩阵快速幂 : 矩阵快速幂模板
AC代码
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <cmath>
#define mst(a) memset(a, 0, sizeof a)
using namespace std;
const int maxn = 1e3+5;
const int mmaxn = 6;
int mod = 6;
typedef long long ll;
int a[maxn][10], b[10][maxn], M[maxn][10], M2[maxn][maxn];
int n, kk;
struct mat{
int s[mmaxn][mmaxn];
mat(){
mst(s);
};
mat operator * (const mat& c) {
mat ans;
for (int i = 0; i < mmaxn; i++) //矩阵乘法
for (int j = 0; j < mmaxn; j++)
for (int k = 0; k < mmaxn; k++)
ans.s[i][j] = (ans.s[i][j] + s[i][k] * c.s[k][j]) % mod;
return ans;
}
}str, c;
mat pow_mod(ll k) {
if (k == 1)
return str;
mat a = pow_mod(k/2);//不能改
mat ans = a * a;
if (k & 1)
ans = ans * str;
return ans;
}
void _baplus(){
for( int k = 0; k < n; k++ )
for( int i = 0; i < kk; i++ )
if( b[i][k] )
for( int j = 0; j < kk; j++ )
str.s[i][j] += b[i][k]*a[k][j];
}
void _acbplus(){
mst(M);
mst(M2);
// m = a*c ( 1000*6*6*6 --> 1000*6 )
for( int k = 0; k < kk; k++ )
for( int i = 0; i < n; i++ )
if( a[i][k] )
for( int j = 0; j < kk; j++ )
M[i][j] += a[i][k]*c.s[k][j];
// m = m*b ( 1000*6*6*1000 --> 1000*1000 )
for( int k = 0; k < kk; k++ )
for( int i = 0; i < n; i++ )
if( M[i][k] )
for( int j = 0; j < n; j++ )
M2[i][j] += M[i][k]*b[k][j];
}
ll getans(){
ll ans = 0;
for( int i = 0; i < n; i++ )
for( int j = 0; j < n; j++ )
ans += M2[i][j]%mod;
return ans;
}
int main()
{
while( scanf("%d%d",&n, &kk) == 2 && n ){
mst(str.s);
mst(c.s);
for( int i = 0; i < n; i++ )
for( int j = 0; j < kk; j++ )
scanf("%d",&a[i][j]);
for( int i = 0; i < kk; i++ )
for( int j = 0; j < n; j++ )
scanf("%d",&b[i][j]);
_baplus();
ll m = n*n-1;
c = pow_mod(m);
_acbplus();
printf("%lld\n",getans());
}
return 0;
}