题目分析:矩阵快速幂。首先我们知道 A^x 可以用矩阵快速幂求出来(具体可见poj 3070)。其次可以对k进行二分,每次将规模减半,分k为奇偶两种情况,如当k = 6和k = 7时有:
ps:对矩阵定义成结构体Matrix,求S时用递归,程序会比较直观,好写一点。当然定义成数组,然后再进行一些预处理,效率会更高些。
注意:
1.开始的时候,一直递归,层数太多,一直TLE,应该吧算出来的暂时存储的(见错误代码);
2.注意矩阵的0次幂
3.注意运算符重载
正确的代码:参考了http://blog.sina.com.cn/s/blog_6635898a0102e1am.html
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
int n,k,m;
struct node{
int matrix[50][50];
};
node a;
//运算符重载
node operator + (node x,node y)//矩阵x+矩阵y
{
node ans;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
ans.matrix[i][j]=(x.matrix[i][j]+y.matrix[i][j])%m;
return ans;
}
node inline mult(node x,node y)//计算矩阵x*y
{
node c;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
{
int ans=0;
for(int p=1;p<=n;p++)//
{
ans+=(x.matrix[i][p]*y.matrix[p][j])%m;
ans%=m;
}
c.matrix[i][j]=ans%m;
}
return c;
}
node inline func(node x,int i)//计算矩阵x^i
{
//printf("%d**\n",i);
node temp,c;
memset(temp.matrix,0,sizeof(temp.matrix));
for(int j=1;j<=n;j++)
temp.matrix[j][j]=1;
if(i==0)
return temp;
if(i==1)
return x;
c=func(x,i/2);
if(i%2==0)
return mult(c,c);
else
return mult(mult(c,c),a);
}
node fun(node A,int x) //计算a^1+a^2+...+a^k
{
if(x==1)
return A;
node B=func(A,(x+1)/2);
node C=fun(A,x/2);
if(x%2==0)
return mult((func(A,0)+B),C);//return B+mult(C,B);
else
return A+mult((A+B),C);//B+mult(C,B)+C;
}
int main()
{
while(scanf("%d %d %d",&n,&k,&m)!=EOF)
{
int i,j;
for(i=1;i<=n;i++)
for(j=1;j<=n;j++)
scanf("%d",&a.matrix[i][j]);
node ans=fun(a,k);
for(i=1;i<=n;i++)
{
printf("%d",ans.matrix[i][1]);
for(j=2;j<=n;j++)
printf(" %d",ans.matrix[i][j]);
printf("\n");
}
}
//system("pause");
return 0;
}
错误的代码:
k = 6 有: S(6) = (1 + A^3) * (A + A^2 + A^3) = (1 + A^3) * S(3)。
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
int n,k,m;
struct node{
int matrix[50][50];
bool flag;
}arr[11000];
node a;
//运算符重载
node operator + (node x,node y)//矩阵x+矩阵y
{
node ans;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
ans.matrix[i][j]=(x.matrix[i][j]+y.matrix[i][j])%m;
return ans;
}
node operator = (node x)
{
/*node ans;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
ans.matrix[i][j]=x.matrix[i][j];*/
return x;
}
node mult(node x,node y)//计算矩阵x*y
{
node c;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
{
int ans=0;
for(int p=1;p<=n;p++)//
{
ans+=(x.matrix[i][p]*y.matrix[p][j])%m;
ans%=m;
}
c.matrix[i][j]=ans%m;
}
return c;
}
node func(node x,int i)//计算矩阵x^i
{
//printf("%d**\n",i);
if(i==1)
return x;
if(i%2==0)
return mult(func(x,i/2),func(x,i/2));
else
return mult(mult(func(x,i/2),func(x,i/2)),a);
}
node fun(int x) //计算a^1+a^2+...+a^k
{
node temp;
if(arr[x].flag==true)
return arr[x];
if(x%2==0)
temp=fun(x/2)+mult(func(a,x/2),fun(x/2));
else
temp=fun(x/2)+mult(func(a,x/2),fun(x/2))+func(a,x);
arr[x]=temp;
arr[x].flag=true;
return temp;
/*if(x==1)
return a;
if(x%2==0)
return fun(x/2)+mult(func(a,x/2),fun(x/2));
else
return fun(x/2)+mult(func(a,x/2),fun(x/2))+func(a,x);*/
}
int main()
{
while(scanf("%d %d %d",&n,&k,&m)!=EOF)
{
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
scanf("%d",&a.matrix[i][j]);
for(int i=1;i<=10000;i++)
arr[i].flag=false;
arr[1]=a;
arr[1].flag=true;
node ans=fun(k);
for(int i=1;i<=n;i++)
{
printf("%d",ans.matrix[i][1]);
for(int j=2;j<=n;j++)
printf(" %d",ans.matrix[i][j]);
printf("\n");
}
}
system("pause");
return 0;
}