求:$S = A + A^2 + A^3 + \cdots + A^k$(结果的每个元素对 $m$ 取模)。
第一次代码:对 $k$ 二分后直接递归,利用 $S_{2t} = S_t + A^t S_t$ 与 $S_{2t+1} = S_t + A^t\,(S_t + A^{t+1})$,很好理解。耗时 1250 ms。
#include <cstdio>
#define SIZE (1<<10)
#define MAX_SIZE 32
#define lint __int64
using namespace std;
// Square integer matrix whose arithmetic results are reduced modulo `modulo`.
// Storage is a fixed MAX_SIZE x MAX_SIZE array; only the top-left
// size x size window is meaningful. Call setSize() and setModulo() before
// using the arithmetic operators.
// NOTE(review): operands and results are passed by value, so every call
// copies the whole MAX_SIZE x MAX_SIZE array; switching to `const CMatrix&`
// would also require changing the out-of-class definitions.
class CMatrix
{
public:
    // Give size/modulo defined values so a default-constructed matrix never
    // carries indeterminate dimensions. element[][] is still left
    // uninitialised until setSize() clears the active window.
    CMatrix () : size ( 0 ), modulo ( 1 ) {}
    int element[MAX_SIZE][MAX_SIZE];
    void setSize(int);          // set dimension and zero the size x size window
    void setModulo(int);        // set the modulus used by the operators
    CMatrix operator* (CMatrix);
    CMatrix operator+ (CMatrix);
    CMatrix power( lint );      // (*this)^exp
    CMatrix sum_exp ( lint );   // A + A^2 + ... + A^exp, A = *this
private:
    int size;
    int modulo;
};
// Record the logical dimension and clear the active x-by-x window.
// Entries outside that window are left untouched.
void CMatrix::setSize ( int x )
{
    size = x;
    for ( int r = 0; r < size; ++r )
    {
        int *row = element[r];
        for ( int c = 0; c < size; ++c )
            row[c] = 0;
    }
}
void CMatrix::setModulo ( int x )
{
modulo = x;
}
// Element-wise sum (*this) + param with each entry reduced modulo `modulo`.
// The result takes its size and modulus from the left operand.
CMatrix CMatrix::operator+ ( CMatrix param )
{
    CMatrix sum;
    sum.setModulo ( modulo );
    sum.setSize ( size );
    for ( int r = 0; r < size; ++r )
    {
        const int *lhs = element[r];
        const int *rhs = param.element[r];
        int *out = sum.element[r];
        for ( int c = 0; c < size; ++c )
            out[c] = ( lhs[c] + rhs[c] ) % modulo;
    }
    return sum;
}
// Matrix product (*this) * param with each entry reduced modulo `modulo`.
// The result takes its size and modulus from the left operand.
// Each cell is accumulated in a 64-bit integer: the product of two int
// entries can reach ~2^62, which would overflow a 32-bit int as soon as
// entries reach 46341.
CMatrix CMatrix::operator* ( CMatrix param )
{
    CMatrix product;
    product.setSize ( size );
    product.setModulo ( modulo );
    for ( int i = 0; i < size; i++ )
        for ( int j = 0; j < size; j++ )
        {
            long long acc = 0;
            for ( int k = 0; k < size; k++ )
                acc = ( acc + static_cast<long long>( element[i][k] ) * param.element[k][j] ) % modulo;
            product.element[i][j] = static_cast<int>( acc );
        }
    return product;
}
// Raise this matrix to the power `exp` by iterative square-and-multiply.
// exp <= 0 yields the identity matrix (its diagonal is 1 % modulo, so it is
// already reduced). For exp >= 1, every entry of the result has passed
// through operator*, and so through `% modulo`.
CMatrix CMatrix::power( lint exp )
{
    CMatrix ret;
    ret.setSize ( size );
    ret.setModulo ( modulo );
    for ( int i = 0; i < size; i++ )
        ret.element[i][i] = 1 % modulo;
    CMatrix base = *this;
    while ( exp > 0 )
    {
        if ( exp & 1 )
            ret = ret * base;
        exp >>= 1;
        // Square only when another bit remains; the last square would be unused.
        if ( exp > 0 )
            base = base * base;
    }
    return ret;
}
// Geometric matrix series S(exp) = A + A^2 + ... + A^exp (mod `modulo`),
// where A is *this. It is computed by halving exp:
//   S(2t)   = S(t) + A^t * S(t)
//   S(2t+1) = S(t) + A^t * (S(t) + A^t * A)
// S(exp) for exp <= 0 is the zero matrix (the empty sum).
CMatrix CMatrix::sum_exp ( lint exp )
{
    if ( exp <= 0 )
    {
        CMatrix zero;
        zero.setSize ( size );
        zero.setModulo ( modulo );
        return zero;
    }
    if ( exp == 1 )
    {
        // S(1) = A. *this may still hold raw input values, so reduce them here;
        // otherwise a top-level call with exp == 1 returns entries >= modulo.
        CMatrix ret = *this;
        for ( int i = 0; i < size; i++ )
            for ( int j = 0; j < size; j++ )
                ret.element[i][j] %= modulo;
        return ret;
    }
    lint mid = exp / 2;
    CMatrix tmps = sum_exp ( mid );
    CMatrix tmpp = power ( mid );
    if ( exp & 1 )
        return tmps + tmpp * ( tmps + tmpp * (*this) );
    return tmps + tmpp * tmps;
}
int main()
{
lint n, k, m;
scanf("%I64d%I64d%I64d",&n,&k,&m);
CMatrix obj, res;
obj.setSize(n);
obj.setModulo(m);
int i, j;
for ( i = 0; i < n; i++ )
for ( j = 0; j < n; j++ )
scanf("%d",&obj.element[i][j]);
res = obj.sum_exp(k);
for ( i = 0; i < n; i++ )
{
for ( j = 0; j < n; j++ )
printf("%d ",res.element[i][j]);
printf("\n");
}
return 0;
}
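第二次代码:
构造分块矩阵 $B = \begin{pmatrix} A & I \\ 0 & I \end{pmatrix}$,直接对 $B$ 做矩阵快速幂:$B^{k-1} \begin{pmatrix} A \\ A \end{pmatrix}$ 的上面 $n$ 行即为 $S$。
但令人纳闷的是,耗时仍有 719 ms。可能的原因有两点:一是 $B$ 是 $2n \times 2n$ 的,每次乘法的代价约为原来的 8 倍;二是 CMatrix 按值传递和返回,每次都要拷贝约 40 KB 的 100×100 数组。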
#include <cstdio>
#define SIZE 100
#define MAX_SIZE 100
using namespace std;
// Square integer matrix whose arithmetic results are reduced modulo `modulo`.
// Storage is a fixed MAX_SIZE x MAX_SIZE array; only the top-left
// size x size window is meaningful. Call setSize() and setModulo() before
// using the arithmetic operators.
// NOTE(review): operands and results are passed by value, so every call
// copies the whole MAX_SIZE x MAX_SIZE array; switching to `const CMatrix&`
// would also require changing the out-of-class definitions.
class CMatrix
{
public:
    // Give size/modulo defined values so a default-constructed matrix never
    // carries indeterminate dimensions. element[][] is still left
    // uninitialised until setSize() clears the active window.
    CMatrix () : size ( 0 ), modulo ( 1 ) {}
    int element[MAX_SIZE][MAX_SIZE];
    void setSize(int);          // set dimension and zero the size x size window
    void setModulo(int);        // set the modulus used by the operators
    CMatrix operator* (CMatrix);
    CMatrix operator+ (CMatrix);
    CMatrix power( int );       // (*this)^exp
    CMatrix sum_exp ( int );    // A + A^2 + ... + A^exp, A = *this
private:
    int size;
    int modulo;
};
// Record the logical dimension and clear the active x-by-x window.
// Entries outside that window are left untouched.
void CMatrix::setSize ( int x )
{
    size = x;
    for ( int r = 0; r < size; ++r )
    {
        int *row = element[r];
        for ( int c = 0; c < size; ++c )
            row[c] = 0;
    }
}
void CMatrix::setModulo ( int x )
{
modulo = x;
}
// Element-wise sum (*this) + param with each entry reduced modulo `modulo`.
// The result takes its size and modulus from the left operand.
CMatrix CMatrix::operator+ ( CMatrix param )
{
    CMatrix sum;
    sum.setModulo ( modulo );
    sum.setSize ( size );
    for ( int r = 0; r < size; ++r )
    {
        const int *lhs = element[r];
        const int *rhs = param.element[r];
        int *out = sum.element[r];
        for ( int c = 0; c < size; ++c )
            out[c] = ( lhs[c] + rhs[c] ) % modulo;
    }
    return sum;
}
// Matrix product (*this) * param with each entry reduced modulo `modulo`.
// The result takes its size and modulus from the left operand.
// Each cell is accumulated in a 64-bit integer: the product of two int
// entries can reach ~2^62, which would overflow a 32-bit int as soon as
// entries reach 46341.
CMatrix CMatrix::operator* ( CMatrix param )
{
    CMatrix product;
    product.setSize ( size );
    product.setModulo ( modulo );
    for ( int i = 0; i < size; i++ )
        for ( int j = 0; j < size; j++ )
        {
            long long acc = 0;
            for ( int k = 0; k < size; k++ )
                acc = ( acc + static_cast<long long>( element[i][k] ) * param.element[k][j] ) % modulo;
            product.element[i][j] = static_cast<int>( acc );
        }
    return product;
}
/*
CMatrix CMatrix::power( int exp )
{
CMatrix tmp = (*this) * (*this);
if ( exp == 1 ) return *this;
else if (exp & 1) return tmp.power(exp/2) * (*this);
else return tmp.power( exp / 2 );
}*/
// Raise this matrix to the power `exp` by iterative square-and-multiply.
// exp <= 0 yields the identity matrix (its diagonal is 1 % modulo, so it is
// already reduced). For exp >= 1, every entry of the result has passed
// through operator*, and so through `% modulo`.
CMatrix CMatrix::power ( int exp )
{
    CMatrix ret;
    ret.setSize(size);
    ret.setModulo(modulo);
    for ( int i = 0; i < size; i++ )
        ret.element[i][i] = 1 % modulo;
    CMatrix base = *this;
    while ( exp > 0 )
    {
        if ( exp & 1 )
            ret = ret * base;
        exp >>= 1;
        // Square only when another bit remains; the last square would be unused.
        if ( exp > 0 )
            base = base * base;
    }
    return ret;
}
// Geometric matrix series S(exp) = A + A^2 + ... + A^exp (mod `modulo`),
// where A is *this. It is computed by halving exp:
//   S(2t)   = S(t) + A^t * S(t)
//   S(2t+1) = S(t) + A^t * (S(t) + A^t * A)
// S(exp) for exp <= 0 is the zero matrix (the empty sum).
CMatrix CMatrix::sum_exp ( int exp )
{
    if ( exp <= 0 )
    {
        CMatrix zero;
        zero.setSize ( size );
        zero.setModulo ( modulo );
        return zero;
    }
    if ( exp == 1 )
    {
        // S(1) = A. *this may still hold raw input values, so reduce them here;
        // otherwise a top-level call with exp == 1 returns entries >= modulo.
        CMatrix ret = *this;
        for ( int i = 0; i < size; i++ )
            for ( int j = 0; j < size; j++ )
                ret.element[i][j] %= modulo;
        return ret;
    }
    int mid = exp / 2;
    CMatrix tmps = sum_exp ( mid );
    CMatrix tmpp = power ( mid );
    if ( exp & 1 )
        return tmps + tmpp * ( tmps + tmpp * (*this) );
    return tmps + tmpp * tmps;
}
void print ( int n, CMatrix res )
{
for ( int i = 0; i < n; i++ )
{
for ( int j = 0; j < n; j++ )
printf("%d ",res.element[i][j]);
printf("\n");
}
}
int main()
{
int n, k, m;
scanf("%d%d%d",&n,&k,&m);
CMatrix obj, tmp, res;
obj.setSize(2*n);
obj.setModulo(m);
int i, j;
for ( i = 0; i < n; i++ )
for ( j = 0; j < n; j++ )
scanf("%d",&obj.element[i][j]);
if ( k == 1 ) { print( n, obj ); return 0; }
for ( i = 0; i < n; i++ )
obj.element[i+n][i+n] = obj.element[i][i+n] = 1;
for ( i = 0; i < n; i++ )
for ( j = 0; j < n; j++ )
tmp.element[i][j] = tmp.element[i+n][j] = obj.element[i][j];
res = obj.power(k-1) * tmp;
print ( n, res );
return 0;
}