Matrix Power Series
Time Limit: 3000MS | Memory Limit: 131072K | |
Total Submissions: 24545 | Accepted: 10186 |
Description
Given an n × n matrix A and a positive integer k, find the sum S = A + A^2 + A^3 + … + A^k.
Input
The input contains exactly one test case. The first line of input contains three positive integers n (n ≤ 30), k (k ≤ 10^9) and m (m < 10^4). Then follow n lines each containing n nonnegative integers below 32,768, giving A’s elements in row-major order.
Output
Output the elements of S modulo m in the same way as A is given.
Sample Input
2 2 4
0 1
1 1
Sample Output
1 2
2 3
题意：给定 n、k、m（模数）以及一个 n×n 的矩阵 a，
要求计算 S = a + a^2 + … + a^k，并输出 S 的每个元素对 m 取模后的结果。
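思路：构造 2n×2n 的分块矩阵 $A = \begin{pmatrix} a & I \\ 0 & I \end{pmatrix}$，其中 $I$ 为 $n \times n$ 单位矩阵，$0$ 为 $n \times n$ 零矩阵。
由此可知 $A^2 = \begin{pmatrix} a^2 & a + I \\ 0 & I \end{pmatrix}$，
$A^3 = \begin{pmatrix} a^3 & a^2 + a + I \\ 0 & I \end{pmatrix}$，
一般地 $A^k = \begin{pmatrix} a^k & I + a + \cdots + a^{k-1} \\ 0 & I \end{pmatrix}$。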
所以用快速幂求出 $A^k$ 后，把它的左上角块 $a^k$ 与右上角块 $I + a + \cdots + a^{k-1}$ 相加，再减去单位矩阵 $I$，就得到 $S = a + a^2 + \cdots + a^k$。
代码:
#include<stdio.h>
#include<string.h>
#include<iostream>
#include<algorithm>
#define N 65
using namespace std;

// Square matrix with 1-based indices. Only the leading 2n x 2n part is used,
// so N = 65 supports n up to 32.
struct node
{
    int m[N][N];
};

// a: the 2n x 2n block matrix [[A, E], [0, E]] built from the input matrix A,
//    where E is the n x n identity.
// I: the 2n x 2n identity matrix (reduced modulo mod).
node a, I;
int n, k, mod;

// Reads the n x n input matrix and builds the globals `a` and `I`.
// Expects n and mod to be already validated by the caller.
// Returns false if the matrix could not be read completely.
bool input()
{
    const int size = 2 * n;
    const int one = 1 % mod; // 0 when mod == 1, so every entry stays in [0, mod)

    for (int i = 1; i <= size; i++) {
        for (int j = 1; j <= size; j++) {
            a.m[i][j] = 0;
            I.m[i][j] = 0;
        }
    }

    // Top-left block: the input matrix, reduced up front so that products in
    // mut() are always formed from values in [0, mod).
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            int value;
            if (scanf("%d", &value) != 1) {
                return false;
            }
            value %= mod;
            if (value < 0) {
                value += mod;
            }
            a.m[i][j] = value;
        }
    }

    // Top-right and bottom-right blocks are identities; bottom-left stays zero.
    for (int i = 1; i <= n; i++) {
        a.m[i][i + n] = one;
        a.m[i + n][i + n] = one;
    }

    for (int i = 1; i <= size; i++) {
        I.m[i][i] = one;
    }
    return true;
}

// Returns the product x * y of two 2n x 2n matrices, modulo mod.
// Both operands must have entries in [0, mod). The operands may alias each
// other (e.g. mut(m, m)); the result is built in a separate local matrix.
node mut(const node& x, const node& y)
{
    const int size = 2 * n;
    node ans;
    for (int i = 1; i <= size; i++) {
        for (int j = 1; j <= size; j++) {
            ans.m[i][j] = 0;
        }
    }

    for (int i = 1; i <= size; i++) {
        for (int t = 1; t <= size; t++) {
            const long long factor = x.m[i][t];
            if (factor == 0) {
                continue; // a large part of the block matrix is zero
            }
            for (int j = 1; j <= size; j++) {
                // 64-bit intermediate: factor * y.m[t][j] can exceed the int range.
                ans.m[i][j] = (int)((ans.m[i][j] + factor * y.m[t][j]) % mod);
            }
        }
    }
    return ans;
}

// Returns a^kk (modulo mod) by binary exponentiation; a^0 is the identity I.
// Works on a local copy, so the global `a` is left untouched.
node solve(int kk)
{
    node ans = I;
    node base = a;
    while (kk > 0) {
        if (kk & 1) {
            ans = mut(ans, base);
        }
        kk >>= 1;
        if (kk > 0) {
            // Square only when another bit remains; the final squaring is never used.
            base = mut(base, base);
        }
    }
    return ans;
}

// Prints S = A + A^2 + ... + A^k (modulo mod), one matrix row per line.
// x must be a^k: its top-left block is A^k and its top-right block is
// E + A + ... + A^(k-1), so S = top-left + top-right - E.
void output(const node& x)
{
    for (int i = 1; i <= n; i++) {
        for (int j = 1; j <= n; j++) {
            long long value = ((long long)x.m[i][j] + x.m[i][j + n]) % mod;
            if (i == j) {
                value--; // subtract the identity on the diagonal
                if (value < 0) {
                    value += mod;
                }
            }
            if (j == n) {
                printf("%d\n", (int)value);
            } else {
                printf("%d ", (int)value);
            }
        }
    }
}

// Processes test cases until end of input or the first malformed/invalid case.
int main()
{
    // Compare against 3: scanf returns 0 on a matching failure, which the
    // original `~scanf(...)` test treated as success and looped on forever.
    while (scanf("%d %d %d", &n, &k, &mod) == 3) {
        if (n < 1 || 2 * n >= N || k < 0 || mod < 1) {
            break; // dimensions would overrun node::m, or the modulus is unusable
        }
        if (!input()) {
            break; // truncated matrix
        }
        output(solve(k));
    }
    return 0;
}