【矩阵加速】
矩阵是什么呢,就是一个由n * m的方阵的二维数组,那么,我们的矩阵到底有什么用呢?且听分解:
相信大家都听说过矩阵加速的吧,矩阵加速就是通过快速幂的方法,将O(n)的时间复杂度降到O(logn),在一些题目中可以起到意想不到的作用!
矩阵乘法
首先要想使用矩阵快速幂,我们就需要知道什么是矩阵乘法。
看这里,百度解释
若是无法忍受百度的说法的话,我精简了一下,矩阵乘法就是:
设A为 n * m 的矩阵,B为 m * p 的矩阵,那么称 n * p 的矩阵C为矩阵A与B的乘积
所以实际上矩阵乘法是一个O(n^3)的算法,那么也就是说矩阵加速在 n 过于大的时候,是不适用的,接下来,我们来看一下,如何进行矩阵加速。
矩阵加速
就像这个名字一样,矩阵加速是用来加速的,那么我们自然而然地就会使用一个加速矩阵。
用题目来理解吧:
e.g.设Fibonacci的第i位为f(i), 有:f(1) = 1, f(2) = 1, f(n) = f(n - 1) + f(n - 2), 给你一个n和mod,求f(n) % mod 的值。
首先我们设一个初始矩阵Origin[1][2] = {1, 1}
代表的分别是f(1), f(2)
那么加速矩阵怎么写呢
可不可以这么想
因为我们经过一次乘法了过后Origin数组就是 {f(2), f(3)}, 且
f(3) = f(1) + f(2)
所以是不是应该推出如此的一个加速矩阵
speed_up[2][2] =
{{0, 1},
{1, 1}};
通过这个矩阵我们便可以将f(1) 变为 f(2), f(2) 变为 f(3)了
延伸一下,我们可以通过乘加速矩阵 n 次使得Origin数组变为f(n + 1), f(n + 2)
那么就可以将加速矩阵进行快速幂处理,再与初始矩阵相乘即可
这便是矩阵加速的精髓了
代码
这里给一下e.g.的代码:
#include <cctype>
#include <cstdio>
#include <cstring>
typedef long long ll;
template <class T>
void r(T &x)
{
#define gc getchar()
x = 0;
char c = gc;
int f = 1;
while (!isdigit(c)) {if (c == '-') f = -1;c = gc;}
while (isdigit(c)) x = (x << 3) + (x << 1) + (c ^ 48), c = gc;
x *= f;
#undef gc
}
const int N = 5;
int n, mod;
struct matrix
{
ll a[N][N];
int n, m;
matrix()
{
memset (a, 0, sizeof a);
n = m = 0;
}
void read()
{
for (int i = 1;i <= n; i++)
for (int j = 1;j <= m; j++)
r(a[i][j]);
}
void print()
{
for (int i = 1;i <= n; i++)
{
for (int j = 1;j < m; j++)
printf ("%d ", a[i][j]);
printf ("%d\n", a[i][m]);
}
}
matrix operator * (matrix x)const
{
matrix ret;
ret.n = n, ret.m = x.m;
for (int i = 1;i <= n; i++)
for (int j = 1;j <= x.m; j++)
for (int k = 1;k <= m; k++)
ret.a[i][j] = (ret.a[i][j] + a[i][k] * x.a[k][j]) % mod;
return ret;
}
friend matrix qkpow(matrix x, int b)
{
matrix ret;
ret.n = ret.m = x.n;
for (int i = 1;i <= x.n; i++)
ret.a[i][i] = 1;
///这个矩阵与任何矩阵相乘都是任何矩阵,类似于快速幂中的 1
while (b)
{
if (b & 1)
ret = x * ret;
x = x * x;
b >>= 1;
}
return ret;
}
}speed_up, origin;
int main()
{
r(n), r(mod);
origin.m = speed_up.n = speed_up.m = 2;
origin.n = 1;
origin.a[1][1] = origin.a[1][2] = speed_up.a[2][1] = speed_up.a[1][2] = speed_up.a[2][2] = 1;
origin = origin * qkpow(speed_up, n - 2);
printf ("%d", origin.a[1][2]);
return 0;
}