十个利用矩阵乘法解决的经典题目中的最后一题。
1A（第一次提交即通过）。
终于刷完这十道题了，思路基本都一样：二分优化！
记 S(k) = a^1 + a^2 + ... + a^k，则 S(2m) = S(m) + a^m * S(m)，S(2m+1) = S(2m) + a^(2m+1)。例如：
a^1 + a^2 + a^3 + a^4 + a^5 + a^6 + a^7 + a^8 用二分变成 a^1 + a^2 + a^3 + a^4 + a^4 * (a^1 + a^2 + a^3 + a^4)
a^1 + a^2 + a^3 + a^4 + a^5 + a^6 + a^7 + a^8 + a^9 用二分变成 a^1 + a^2 + a^3 + a^4 + a^4 * (a^1 + a^2 + a^3 + a^4) + a^9
AC代码如下:
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
// Matrix arrays are indexed from 1, so the largest dimension that fits is
// MAX_N - 1 (= 32).
const int MAX_N = 33;
// Per-test-case parameters read in main(): MOD is the modulus applied to every
// entry, N is the matrix dimension, and K is the highest power in the summed
// series passed to solve().
int MOD, N, K;
// Matrix product c = a * b (mod MOD) over 1-indexed arrays: a is am x an,
// b is bm x bn, and c receives the am x bn result. bm is accepted for symmetry
// but never read; the caller must ensure an == bm.
// c must not alias a or b: an entry of c is overwritten while the matching
// entries of a / b may still be needed for later entries.
void multipy( int a[MAX_N][MAX_N], int am, int an, int b[MAX_N][MAX_N], int bm, int bn, int c[MAX_N][MAX_N] ){
    for( int i = 1; i <= am; i++ ){
        for( int j = 1; j <= bn; j++ ){
            // Accumulate in 64 bits: the product of two int entries overflows
            // int once the values exceed about 46341 (e.g. a large MOD, or raw
            // input entries that were never reduced).
            long long sum = 0;
            for( int k = 1; k <= an; k++ ){
                sum = ( sum + 1LL * a[i][k] * b[k][j] ) % MOD;
            }
            c[i][j] = static_cast<int>( sum );
        }
    }
}
// In-place matrix sum a += b (mod MOD) over the 1-indexed m x n corner.
void add_matrix( int a[MAX_N][MAX_N], int b[MAX_N][MAX_N], int m, int n ){
    for( int i = 1; i <= m; i++ ){
        for( int j = 1; j <= n; j++ ){
            // Widen before adding so the sum of two large entries cannot
            // overflow int (possible when MOD > INT_MAX / 2 or an operand is
            // not yet reduced).
            const long long sum = static_cast<long long>( a[i][j] ) + b[i][j];
            a[i][j] = static_cast<int>( sum % MOD );
        }
    }
}
// Replaces the N x N (1-indexed) matrix a with a^n (mod MOD) using
// square-and-multiply. n == 0 leaves the identity in a. While running, a holds
// the current repeated square a^(2^t).
void get_matrix_pow( int a[MAX_N][MAX_N], int n ){
    const size_t matrix_bytes = sizeof( int ) * MAX_N * MAX_N;
    int result[MAX_N][MAX_N] = {0};
    int scratch[MAX_N][MAX_N];
    for( int d = 1; d <= N; d++ ) result[d][d] = 1;

    for( int e = n; e != 0; e /= 2 ){
        // Fold the current square into the result when this bit of n is set.
        if( e % 2 != 0 ){
            multipy( result, N, N, a, N, N, scratch );
            memcpy( result, scratch, matrix_bytes );
        }
        // Square for the next bit.
        multipy( a, N, N, a, N, N, scratch );
        memcpy( a, scratch, matrix_bytes );
    }
    memcpy( a, result, matrix_bytes );
}
// Helper for solve(). On return, a = b^1 + b^2 + ... + b^n and p = b^n, both
// reduced mod MOD after at least one step. For n <= 0, a is the zero matrix
// (the empty sum) and p is the identity.
// Carrying b^n alongside the partial sum means each recursion level costs a
// constant number of matrix products, O(N^3 log n) in total. The previous
// version recomputed b^(n/2) and b^n from scratch at every level.
static void solve_sum_and_power( int a[MAX_N][MAX_N], int p[MAX_N][MAX_N], int b[MAX_N][MAX_N], int n ){
    if( n <= 0 ){
        for( int i = 1; i <= N; i++ ){
            for( int j = 1; j <= N; j++ ){
                a[i][j] = 0;
                p[i][j] = ( i == j ) ? 1 : 0;
            }
        }
        return;
    }
    int temp[MAX_N][MAX_N];
    // With h = n / 2, this leaves a = S(h) and p = b^h.
    solve_sum_and_power( a, p, b, n / 2 );
    // S(2h) = S(h) + b^h * S(h)
    multipy( p, N, N, a, N, N, temp );
    add_matrix( a, temp, N, N );
    // b^(2h) = b^h * b^h
    multipy( p, N, N, p, N, N, temp );
    memcpy( p, temp, sizeof( int ) * MAX_N * MAX_N );
    if( n & 1 ){
        // S(2h + 1) = S(2h) + b^(2h + 1)
        multipy( p, N, N, b, N, N, temp );
        memcpy( p, temp, sizeof( int ) * MAX_N * MAX_N );
        add_matrix( a, p, N, N );
    }
}

// Sets a = b^1 + b^2 + ... + b^n (mod MOD) for the N x N (1-indexed) matrix b.
// b is not modified. n <= 0 yields the zero matrix. For n >= 1 every entry is
// reduced mod MOD, including n == 1, which previously returned b unreduced.
void solve( int a[MAX_N][MAX_N], int b[MAX_N][MAX_N], int n ){
    int power[MAX_N][MAX_N];
    solve_sum_and_power( a, power, b, n );
}
// Reads test cases of the form "N K MOD" followed by an N x N matrix until the
// input runs out. For each case, prints A + A^2 + ... + A^K (mod MOD), one
// matrix row per line.
int main(){
    int a[MAX_N][MAX_N], b[MAX_N][MAX_N];
    // Require all three fields. Comparing only against EOF would loop forever
    // on malformed input, because scanf keeps returning 0 without consuming it.
    while( scanf( "%d%d%d", &N, &K, &MOD ) == 3 ){
        // The arrays are 1-indexed, so N must stay below MAX_N. MOD is used as
        // a divisor everywhere and must be positive.
        if( N >= MAX_N || MOD <= 0 ){
            break;
        }
        bool complete = true;
        for( int i = 1; complete && i <= N; i++ ){
            for( int j = 1; complete && j <= N; j++ ){
                complete = ( scanf( "%d", &b[i][j] ) == 1 );
            }
        }
        if( !complete ){
            break;
        }
        solve( a, b, K );
        for( int i = 1; i <= N; i++ ){
            printf( "%d", a[i][1] );
            for( int j = 2; j <= N; j++ ){
                printf( " %d", a[i][j] );
            }
            printf( "\n" );
        }
    }
    return 0;
}