题目链接:
题目大意:
给出一些字符串,不超过50个,取出m个,每种不限量,问能拼凑出多少种串
题目分析:
这道题很裸的矩阵快速幂,但是注意要去重,要不然会wa的很惨
首先既然是利用矩阵快速幂进行优化,那么一定是存在一个递推式的
dp[i][j]表示选取i个字符串且以第j种结尾的种类数
dp[i][j] = sum ( dp[i-1][k] ) (k是后面可以链接j的)
所以我们构建50*50的矩阵进行递推:
初始的矩阵定为:
1 1 1 .......1 1 1 1
0 0 0 .......0 0 0 0
..... ....
0 0 ..........0 0 0 0
转换矩阵是能够相连的字符串的邻接矩阵。
然后具体的看代码:
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#define MAX 51
using namespace std;
typedef long long LL;
int t,n,m;
const LL mod = 1000000007LL;
char s[MAX][15];
struct Matrix
{
int a[MAX][MAX];
};
void _reset ( Matrix& m )
{
memset ( m.a , 0 , sizeof ( m.a ));
}
void _set ( Matrix &m )
{
_reset ( m );
for ( int i = 0; i < n ; i++ )
m.a[i][i] = 1;
}
Matrix multi ( Matrix m1 , Matrix m2 )
{
Matrix ret;
_reset ( ret );
for ( int i = 0 ; i < n ; i++ )
for ( int j = 0 ; j < n ; j++ )
if ( m1.a[i][j] )
for ( int k = 0 ; k < n ; k++ )
{
ret.a[i][k] += ((LL)(m1.a[i][j])*(LL)(m2.a[j][k]))%mod;
ret.a[i][k] %= mod;
}
return ret;
}
Matrix quick ( Matrix m , int n )
{
Matrix ret;
_set ( ret );
while ( n )
{
if ( n&1 ) ret = multi ( ret , m );
m = multi ( m , m );
n >>= 1;
}
return ret;
}
void print ( Matrix m )
{
puts ( " -----------print the matrix----------- ");
for ( int i = 0 ; i < n ; i++ )
{
for ( int j = 0 ; j < n ; j++ )
printf ( "%d " , m.a[i][j] );
puts("");
}
puts ("------------print end matrix-------------" );
}
int main ( )
{
scanf ( "%d" , &t );
Matrix a,b;
while ( t-- )
{
scanf ( "%d%d" , &n , &m );
char ss[20];
int nn = 0;
for ( int i = 0 ; i < n ; i++ )
{
scanf ( "%s" , ss);
bool flag = false;
for ( int j = 0 ; j < nn ; j++ )
if ( !strcmp ( ss , s[j] ) )
flag = true;
if ( flag ) continue;
int len = strlen ( ss );
int j = 0;
while ( j < len )
{
s[nn][j] = ss[j];
j++;//必须写在这,写在上面应该写在前面,因为后面的先执行
}
s[nn++][j] = 0;
//cout << nn << " " << s[nn-1] << " " << ss << endl;
}
n = nn;
_reset ( a );
for ( int i = 0 ; i < n ; i++ )
a.a[0][i] = 1;
_reset ( b );
for ( int i = 0 ; i < n ; i++ )
for ( int j = 0 ; j < n ; j++ )
{
int len1 = strlen ( s[i] );
int len2 = strlen ( s[j] );
if ( len1 < 2 || len2 < 2 ) continue;
for ( int k = 0; k < len1-1 ; k++ )
{
bool flag = true;
int t;
for ( t = 0 ; k+t < len1 &&t < len2 ; t++ )
if ( s[i][k+t] != s[j][t] )
{
flag = false;
break;
}
if ( k+t != len1 ) flag = false;
if ( flag )
{
b.a[j][i] = 1;
break;
}
}
}
//print ( b );
b = quick ( b , m-1 );
a = multi ( a , b );
int ans = 0;
for ( int i = 0 ; i < n ; i++ )
{
ans += a.a[0][i];
ans %= mod;
}
//print ( a );
printf ( "%d\n" , ans );
}
}