最近做了几个题都关于矩阵快速幂,所以把这部分都好好的看了看。
首先要知道什么是快速幂。如果想算一个数字a的9次方,常规的想法可能是算a*a*a*a*a*a*a*a*a,也就是乘以九次a。很明显,这个算法的复杂度是O(n),怎么能加快呢,算a的九次方,先算a的四次方,再平方这个结果,再乘以a。就是这样:(a^4)*(a^4)*a;,然后呢,a^4可以写成a^2*a^2,a^2写成a*a,这样子分解下来,算法是非常快的,复杂度是O(lg(n)),大家也知道,这俩复杂度相差不是一点半点。所以,是值得学习的一种操作。
那这部分在实现时,怎么实现呢,还借助了位运算,n>>=1 ,这个是什么意思呢,就是说,n这个数字的二进制,向右移了一位,比如n=9=1001,执行之后就是n=4=100,
实现代码如下:
void Pow(ll a[][N],ll n)
{
memset(res,0,sizeof res);//n是幂,N是矩阵大小
for(int i=0;i<N;i++) res[i][i]=1;
while(n)
{
if(n&1)
multi(res,a,N);//res=res*a;复制直接在multi里面实现了;
multi(a,a,N);//a=a*a
n>>=1;
}
}
其中res矩阵是初始矩阵,也是最终的结果矩阵,mulit这个函数是矩阵相乘,可以先不用管它。关键那个判断句,如果当然数字是奇数,就把a乘到结果矩阵里,若不是,则不乘,然后,不管是否为奇数,a都和自身相乘。我一开始没看懂这个操作,后来拿笔计算一下,就明白了。建议大家也拿笔算一下。快速幂就是这样。
矩阵快速幂,就是数字a不再是个数字了,是个矩阵,所以我们需要先知道矩阵乘法,这个我就不说了,线性代数书上也有,百度也一大把。自己搜。矩阵相乘没什么好说的,直接上全部代码。
#define ll long long int
const ll mod=1000000009;
const ll N=3;
ll tmp[N][N];
void multi(ll a[][N],ll b[][N],ll n)
{
memset(tmp,0,sizeof tmp);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
for(int k=0;k<n;k++)
tmp[i][j]=(tmp[i][j]+(a[i][k]*b[k][j])%mod)%mod;
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
a[i][j]=tmp[i][j];
}
ll res[N][N];
void Pow(ll a[][N],ll n)
{
memset(res,0,sizeof res);//n是幂,N是矩阵大小
for(int i=0;i<N;i++) res[i][i]=1;
while(n)
{
if(n&1)
multi(res,a,N);//res=res*a;复制直接在multi里面实现了;
multi(a,a,N);//a=a*a
n>>=1;
}
}
很清晰的代码结构。没什么好说的 ,有疑问再交流。
这些都是模板,关键的是,怎么用这个呢。一搜矩阵快速幂,就有斐波那契的例子。这个是非常经典,如图:
一看系数矩阵,往前面的模板一套,就出结果了,爽是爽,但下次遇到,不会用就凉了,所以,还是要搞懂,这个系数矩阵怎么来的。我在理解的时候,都是用高中的数列去套的,怎么套。现在相当于找等比数列的公比,直接给的f(n),并不一定是等比数列,怎么办,就是先构造一个等比数列,这个等比数列里包含了f(n).那么就可以借助这个新的数列,去求f(n)的通项。首先,题目一般式给了f(n)=af(b-1)+bf(n-2),那这个在构造时,就以f(n+1)为等比数列的A(n+1),去找A(n)乘以一个什么样的矩阵,得到这个A(n),这个是比较容易的,[a,b][1,0],我不会打矩阵,我就一行一行写了。可以发现,f(n)乘以这个矩阵是等于f(n+1)的,就目前来说,我碰到的简单题大部分都是这样算出来的。就是可以f(n+1)写左边,f(n)写右边,去找系数矩阵。这个51Nod上有个题就可以去做的,链接:http://www.51nod.com/onlineJudge/questionCode.html#!problemId=1126
非常的典型;然后稍微难一点的就是f(n)这个式子后面还有一串,比如这个题:https://www.nowcoder.com/acm/contest/105/G
这个怎么去构造,其实都是一样的,也是先把f(n+1),写左边,f(n)写右边,然后,在算系数是时,记得拆开(n+1)的次方时,注意系数别写错了。手写两遍,就懂了。这个系数矩阵是这样的:
ll a[6][6]={
1,1,1,1,1,1,
1,0,0,0,0,0,
0,0,1,3,3,1,
0,0,0,1,2,1,
0,0,0,0,1,1,
0,0,0,0,0,1
};
只要会推这个系数矩阵,剩下就是套模板。当然也有模板失手的时候,比如这个题,就不能简单的套模板:我也顺便写了题解,来看看吧:点击打开链接
附上文中提到的题的ac代码:
51nod:
#include<bits/stdc++.h>
using namespace std;
#define ll long long int
const int mod=7;
const int N=2;
int tmp[N][N];
void multi(int a[][N],int b[][N],int n)
{
memset(tmp,0,sizeof tmp);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
for(int k=0;k<n;k++)
tmp[i][j]=(tmp[i][j]+((a[i][k]*b[k][j])%mod+7)%mod)%mod;
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
a[i][j]=tmp[i][j];
}
int res[N][N];
void Pow(int a[][N],int n)
{
memset(res,0,sizeof res);//n是幂,N是矩阵大小
for(int i=0;i<N;i++) res[i][i]=1;
while(n)
{
if(n&1)
multi(res,a,N);//res=res*a;复制直接在multi里面实现了;
multi(a,a,N);//a=a*a
n>>=1;
}
}
int main()
{
#ifdef LOCAL
freopen("D:/input.txt" , "r", stdin);
#endif
int a[2][2];
int n=5;
int A, B, N;
cin>>A>>B>>N;
// for(int i=3;i<10;i++)
// {
memset(a,0,sizeof(a));
a[0][0]=A;
a[0][1]=B;
a[1][0]=1;
a[1][1]=0;
Pow(a,N-2);
cout<<(res[0][1]+res[0][0])%mod<<endl;
// }
return 0;
}
牛客网:
#include<bits/stdc++.h>
using namespace std;
#define ll long long int
const ll mod=1000000007;
const ll N=6;
ll tmp[N][N];
void multi(ll a[][N],ll b[][N],ll n)
{
memset(tmp,0,sizeof tmp);
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
for(int k=0;k<n;k++)
tmp[i][j]=(tmp[i][j]+(a[i][k]*b[k][j])%mod)%mod;
for(int i=0;i<n;i++)
for(int j=0;j<n;j++)
a[i][j]=tmp[i][j];
}
ll res[N][N];
void Pow(ll a[][N],ll n)
{
memset(res,0,sizeof res);//n是幂,N是矩阵大小
for(int i=0;i<N;i++) res[i][i]=1;
while(n)
{
if(n&1)
multi(res,a,N);//res=res*a;复制直接在multi里面实现了;
multi(a,a,N);//a=a*a
n>>=1;
}
}
int main()
{
#ifdef LOCAL
freopen("D:/input.txt" , "r", stdin);
#endif
// ll a[N][N];
ll n, m;
int t;
cin>>t;
while(t--)
{
ll a[6][6]={
1,1,1,1,1,1,
1,0,0,0,0,0,
0,0,1,3,3,1,
0,0,0,1,2,1,
0,0,0,0,1,1,
0,0,0,0,0,1
};
cin>>n;
if(n==0)
{
cout<<0<<endl;
continue;
}
else if(n==1)
{
cout<<1<<endl;
continue;
}
memset(res,0,sizeof(res));
Pow(a,n-1);
ll sum=0;
int qwe[6]={1,0,8,4,2,1};
for(int i=0;i<6;i++)
{
sum=(sum+(res[0][i]*qwe[i])%mod)%mod;
}
cout<<sum%mod<<endl;
}
return 0;
}