定义
1.矩阵
在数学中,矩阵(Matrix)是一个按照长方阵列排列的实数或复数集合,元素是实数的矩阵称为实矩阵,元素是复数的矩阵称为复矩阵。而行数与列数都等于n的矩阵称为n阶矩阵或n阶方阵。n阶方阵中所有i=j的元素aij组成的斜线称为(主)对角线,所有i+j=n+1的元素aij组成的斜线称为辅对角线。
2.矩阵乘法
两个矩阵的乘法仅当第一个矩阵A的列数和第二个矩阵B的行数相等时才能定义(做乘法)。如A是m×n矩阵,B是n×p矩阵,它们的乘积C是一个m×p矩阵C=(cij)。
且公式为:
其实这些都是很基础的东西,看看就行了
但是:
这些就非常重要了
实现
既然有了定义,那么我们就可以去实现了
其实很简单,用结构体:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
const int MAXN = 103;
inline void read_(int &x) {
x=0;
int f=1;
char s=getchar();
while(s<'0'||s>'9') {
if(s=='-')
f=-1;
s=getchar();
}
while(s>='0'&&s<='9') {
x=x*10+s-48;
s=getchar();
}
x*=f;
}
struct node{
int s[MAXN][MAXN] , n , m;
void read( ){//结构体读入矩阵
for( int i = 1 ; i <= n ; i ++ )
for( int j = 1 ; j <= m; j ++ )
read_(s[i][j]);
}
void print(){//结构体输出矩阵
for( int i = 1 ; i <= n ; i ++ ){
for( int j = 1 ; j < m ; j ++ )
printf( "%d " , s[i][j] );
printf( "%d" , s[i][m] );
if( i < n )
printf( "\n" );
}
}
node operator * ( const node &a ){
node r;//r存结果
memset( r.s , 0 , sizeof( r.s ) );//以防万一
r.n = n , r.m = a.m;
for( int i = 1 ; i <= r.n ; i ++ )
for( int j = 1 ; j <= r.m; j ++ )
for( int k = 1 ; k <= m ; k ++ )
r.s[i][j] += s[i][k] * a.s[k][j];//公式的套用
return r;//最后记得返回值
}
}a , b , c;
int main(){
read_( a.n );read_( a.m );//a,b为加数,c为结果
b.n = a.m;//b的行与a的列必须相等
a.read();
read_( b.m );
b.read();
c = a * b;
c.print();
return 0;
}
如果博客只到这里,那么就太没有水平了~~
既然题目是矩阵加速,我们怎么进行加速呢?
其实网上说矩阵主要是有三个用途的,可是我只会一种~~
1.解线性方程组
2.方程降次
3.变换(比如什么平移、缩放、旋转、斜切的,也许以后会用到吧)
先来看一道很简单的题:
1.求斐波拉契的第n项
其实递推式很简单,但是如果这道题n超1e9怎么办呢
这时候就需要用到矩阵加速了
我们知道f[i] = f[i-1] + f[i-2]
于是我们定义矩阵A[f1 , f2]
矩阵B为
其实我们就可发现AB =
所以我们就可以求到了
这是我们需要用到:矩阵快速幂
矩阵A的b次方
node qpow( node a , int b ){//我们把矩阵只有主对角线为1的矩阵称作单位矩阵
node t ;
t.n = t.m = a.m;
t.s[1][1] = 1 , t.s[1][2] = 0 , t.s[2][1] = 0 , t.s[2][2] = 1;
while( b ){//其实和数的快速幂差不多
if( b % 2 ) t = t * a;//这里用到的是矩阵乘法
a = a * a;
b /= 2;
}
return t;
}
那么我们想:乘一个B矩阵就可以求到f3,那么乘n-2个B矩阵就可以得到fn了吧
还记得矩阵满足结合律吗?AB C = A (BC)
所以我们可以先算乘方,再算乘法
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;
#define ll long long
int n , mod;
struct node{
ll s[103][103] ;
int n , m;
void read( ){
for( int i = 1 ; i <= n; i ++ )
for( int j = 1 ; j <= m ; j ++ )
scanf( "%lld" , &s[i][j] );
}
void print( ){
for( int i = 1 ; i <= n ; i ++ ){
for( int j = 1 ; j < m ; j ++ )
printf( "%d " , s[i][j] );
printf( "%d" , s[i][m] );
if( i != n )
printf( "\n" );
}
}
node operator * ( const node & a ){
node r;
memset( r.s , 0 , sizeof( r.s ) );
r.n = n , r.m = a.m;
for( int i = 1 ; i <= r.n ; i += 1 ){
for( int j = 1 ; j <= r.m ; j ++ ){
for( int k = 1 ; k <= m ; k ++ ){
r.s[i][j] += s[i][k] * a.s[k][j];
if( r.s[i][j] > mod )
r.s[i][j] %= mod;
}
}
}
return r;
}
}a , b , c;
node qpow( node a , int b ){
node t ;
t.n = t.m = a.m;
t.s[1][1] = 1 , t.s[1][2] = 0 , t.s[2][1] = 0 , t.s[2][2] = 1;
while( b ){
if( b % 2 ) t = t * a;
a = a * a;
b /= 2;
}
return t;
}
int main(){
scanf( "%d%d" , &n , &mod );
a.n = 1 , a.m = 2;
b.n = b.m = 2;
a.s[1][1] = 1 , a.s[1][2] = 1;
b.s[1][1] = 0 , b.s[1][2] = 1 , b.s[2][1] = 1 , b.s[2][2] = 1;
c = qpow( b , n - 2 );
node ans = a * c;
printf( "%lld" , ans.s[1][2] );
return 0;
}
代码就没什么可讲的了