矩阵相乘:
先看一下矩阵乘法的代码(题目链接)
#include<stdio.h>
#include<string.h>
using namespace std;
const int maxn=100+7;
int a[maxn][maxn],b[maxn][maxn],c[maxn][maxn];
int main()
{
int n,i,j,k,sum;
while(~scanf("%d",&n))
{
memset(c,0,sizeof(c));
for(i=0;i<n;i++)
{
for(j=0;j<n;j++)
{
scanf("%d",&a[i][j]);
}
}
for(i=0;i<n;i++)
{
for(j=0;j<n;j++)
{
scanf("%d",&b[i][j]);
}
}
for(i=0;i<n;i++)
{
for(j=0;j<n;j++)
{
for( k=0;k<n;k++)
{
c[i][j]+=a[i][k]*b[k][j];
}
}
}
for(i=0;i<n;i++)
{
for(j=0;j<n;j++)
{
printf("%d ",c[i][j]);
}
printf("\n");
}
}
return 0;
}
上述代码时间复杂度为o(n^3)。
普通的快速幂(求x^n,res即为所求):
ll pow_mod(ll x,ll n){
ll res=1;
while(n){
if(n&1)
res=res*x;
x=x*x;
n>>=1;
}
return res;
}
矩阵快速幂可以把每个矩阵看成一个数,代入上面的普通的快速幂代码中,res=1,就变成单位矩阵了(主对角线上都为1,其它都为0的矩阵)。
代码实现:
void multi(ll a[][N],ll b[][N],int n)
{
memset(tmp,0,sizeof(tmp));
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
for(int k=0;k<n;k++)
{
tmp[i][j]+=(a[i][k]*b[k][j])%mod;//根据题意而定需不需要模mod
}
tmp[i][j]=tmp[i][j]%mod;
}
}
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
a[i][j]=tmp[i][j];
}
void Pow(ll a[][N],ll m,int n)
{
memset(res,0,sizeof(res));//m是幂,n是矩阵大小
for(int i=0;i<n;i++) res[i][i]=1;//单位矩阵
while(m)
{
if(m&1)
multi(res,a,n);//res=res*a;复制直接在multi里面实现了;
multi(a,a,n);//a=a*a
m>>=1;
}
}
上述代码就相当于求得a^m的过程(a为矩阵),普通快速幂中的res是返回一个值,我们这里可以把res矩阵定义成全局变量,或通过传参,以达到在主函数中是被改变后的res矩阵(res矩阵即为我们要求的a^m,a为矩阵)。
下面给出整个过程的完整代码(题目链接):
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
typedef long long ll;
const int mod=1e9+7;
const int N=100;
ll tmp[N][N],res[N][N];
void multi(ll a[][N],ll b[][N],int n)
{
memset(tmp,0,sizeof(tmp));
for(int i=0;i<n;i++)
{
for(int j=0;j<n;j++)
{
for(int k=0;k<n;k++)
{
tmp[i][j]+=(a[i][k]*b[k][j])%mod;
}
tmp[i][j]=tmp[i][j]%mod;
}
}
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
a[i][j]=tmp[i][j];
}
void Pow(ll a[][N],ll m,int n)
{
memset(res,0,sizeof(res));//m是幂,n是矩阵大小
for(int i=0;i<n;i++) res[i][i]=1;//单位矩阵
while(m)
{
if(m&1)
multi(res,a,n);//res=res*a;复制直接在multi里面实现了;
multi(a,a,n);//a=a*a
m>>=1;
}
}
int main()
{
int n,i,j;
ll a[N][N],m;
while(~scanf("%d%lld",&n,&m))
{
memset(a,0,sizeof(a));
for(i=0;i<n;i++)
{
for(j=0;j<n;j++)
{
scanf("%lld",&a[i][j]);
}
}
Pow(a,m,n);
for(i=0;i<n;i++)
{
for(j=0;j<n;j++)
{
printf("%lld ",res[i][j]);
}
printf("\n");
}
}
return 0;
}