hdu 4878 ZCC loves words(AC自动机+dp+矩阵快速幂+中国剩余定理)
题意:给出若干个模式串,总长度不超过40,对于某一个字符串,它有一个价值,对于这个价值的计算方法是这样的,设初始价值为V=1,假如这个串能匹配第k个模式串,则V=V*prime[k]*(i+len[k]),其中prime[k]表示第k个素数,i表示匹配的结束位置,len[k]表示第k个模式串的长度(注意,一个字符串可以多次匹配同意个模式串)。问字符集为'A'-'Z'的字符,组成的所有的长为L的字符串,能得到的总价值和是多少?
解法:跟以前做过的很多AC自动机的题有点类似,很容易想到一个node*L的dp,dp[i][v]表示长为i,匹配到AC自动机的V节点能得到的价值和(详见代码中的DEBUG函数)。但是L太大,没法搞。节点总数只有40,那么就可以用矩阵来加速dp了,但是很可惜,建立矩阵的时候,发现建的矩阵居然是跟i有关,这样是不能直接用矩阵快速幂做的。但是,题目给出的提示是,mod可以拆成三个较小的质数。那么我们可以分别用三个较小的质数作为mod进行运算,因为第i个矩阵,它是跟第i+mod个矩阵一样的,所以我们可以把L个矩阵分成L/mod段,每一段的矩阵乘起来都是一样的,设其为A(可以暴力乘起来,因为mod很小),那么我们要的所有的L个矩阵的乘起来得到的矩阵,就是A^(L/mod),再乘上剩下来多余的L%mod个了,这样就可以计算出在每个较小的模系下的答案。最后用中国剩余定理计算总的答案。
代码:
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<vector>
#include<queue>
#define ll __int64
using namespace std ;
const int N = 44 ;
const int mod = 5047621 ;
int pri[12345] , p_num , vis[12345] ;
void get_prime () {
p_num = 0 ;
for ( int i = 2 ; i < 12345 ; i ++ ) {
if ( !vis[i] ) pri[++p_num] = i ;
for ( int j = 1 ; j <= p_num ; j ++ ) {
if ( i * pri[j] >= 12345 ) break ;
vis[i*pri[j]] = 1 ;
if ( i % pri[j] == 0 ) break ;
}
}
}
struct Point {
int len , p ;
Point () {}
Point ( int a , int b ):len(a),p(b) {}
} ;
struct RECT {
int elem[N][N] ;
void print ( int n ) {
for ( int i = 0 ; i < n ; i ++ , puts ( "" ) )
for ( int j = 0 ; j < n ; j ++ )
printf ( "%d " , elem[i][j] ) ;
}
} p[222] , E ;
struct AC_auto {
int dp[111][44] ;
int c[N][26] , fail[N] , tot ;
vector<Point> vec[N] ;
queue<int> Q ;
void init () {
tot = 0 ;
new_node () ;
}
int new_node () {
vec[tot].clear () ;
fail[tot] = 0 ;
memset ( c[tot] , 0 , sizeof ( c[tot] ) ) ;
return tot ++ ;
}
void insert ( char *s , int i ) {
int now = 0 , len = strlen ( s ) ;
for ( ; *s ; s ++ ) {
int k = *s - 'A' ;
if ( !c[now][k] ) c[now][k] = new_node () ;
now = c[now][k] ;
}
vec[now].push_back ( Point ( len , pri[i] ) ) ;
}
void get_fail () {
int u = 0 , v ;
for ( int i = 0 ; i < 26 ; i ++ ) {
if ( c[u][i] )
Q.push ( c[u][i] ) ;
}
while ( !Q.empty () ) {
u = Q.front () ; Q.pop () ;
for ( int i = 0 ; i < 26 ; i ++ ) {
if ( c[u][i] ) {
v = c[u][i] ;
fail[v] = c[fail[u]][i] ;
Q.push ( v ) ;
} else c[u][i] = c[fail[u]][i] ;
}
}
}
void BUILD_RECT ( int l , int mod ) {
memset ( p[l].elem , 0 , sizeof ( p[l].elem ) ) ;
for ( int i = 0 ; i < tot ; i ++ ) {
for ( int j = 0 ; j < 26 ; j ++ ) {
int u = c[i][j] ;
int v = u , ret = 1 ;
while ( v ) {
for ( int k = 0 ; k < vec[v].size () ; k ++ ) {
Point u = vec[v][k] ;
ret *= (l+u.len)*u.p ;
ret %= mod ;
}
v = fail[v] ;
}
p[l].elem[u][i] += ret ;
if ( p[l].elem[u][i] >= mod )
p[l].elem[u][i] -= mod ;
}
}
}
void RECT_MUIL ( RECT x , RECT y , RECT &z , int mod ) {
memset ( z.elem , 0 , sizeof ( z.elem ) ) ;
for ( int i = 0 ; i < tot ; i ++ ) {
for ( int j = 0 ; j < tot ; j ++ )
for ( int k = 0 ; k < tot ; k ++ ) {
z.elem[i][j] += x.elem[i][k] * y.elem[k][j] % mod ;
if ( z.elem[i][j] >= mod )
z.elem[i][j] -= mod ;
}
}
}
void GAO ( RECT& ret , ll n , int mod ) {
// printf ( "n = %I64d\n" , n ) ;
RECT f = ret ; ret = E ;
while ( n ) {
if ( n & 1 ) RECT_MUIL ( ret , f , ret , mod ) ;
RECT_MUIL ( f , f , f , mod ) ;
n >>= 1 ;
}
}
int SOLVE ( int mod , ll l ) {
RECT ans = E , temp = E ;
// printf ( "mod = %d\n" , mod ) ;
for ( int i = mod ; i >= 1 ; i -- ) {
BUILD_RECT ( i , mod ) ;
RECT_MUIL ( temp , p[i] , temp , mod ) ;
// if (i == 1) ans.print ( tot ) ;
}
// puts( "fuck ") ;
GAO ( temp , l/mod , mod ) ;
// ans.print ( tot ) ;
for ( int i = l % mod ; i >= 1 ; i -- ) {
BUILD_RECT ( i , mod ) ;
RECT_MUIL ( ans , p[i] , ans , mod ) ;
}
RECT_MUIL ( ans , temp , ans , mod ) ;
// ans.print ( tot ) ;
int ret = 0 ;
for ( int i = 0 ; i < tot ; i ++ ) {
ret += ans.elem[i][0] ;
if ( ret >= mod ) ret -= mod ;
}
return ret ;
}
void DEBUG ( ll l ) {
memset ( dp , 0 , sizeof ( dp ) ) ;
dp[0][0] = 1 ;
for ( int i = 0 ; i < l ; i ++ ) {
for ( int j = 0 ; j < tot ; j ++ ) {
for ( int k = 0 ; k < 26 ; k ++ ) {
int u = c[j][k] ;
int v = u ;
int ret = 1 ;
while ( v ) {
for ( int g = 0 ; g < vec[v].size () ; g ++ ) {
Point f = vec[v][g] ;
ret *= (i+1+f.len) * f.p ;
}
v = fail[v] ;
}
dp[i+1][u] += dp[i][j] * ret % mod ;
if ( dp[i+1][u] >= mod ) dp[i+1][u] -= mod ;
}
}
}
int ans = 0 ;
for ( int i = 0 ; i < tot ; i ++ ) {
ans += dp[l][i] ;
if ( ans >= mod ) ans -= mod ;
}
puts ( "fuck" ) ;
printf ( "%d\n" , ans ) ;
}
} ac ;
void extend_gcd ( ll a , ll b , int &x , int &y ) {
if ( !b ) x = 1 , y = 0 ;
else extend_gcd ( b , a % b , y , x ) , y -= x * ( a / b ) ;
}
char s[1111] ;
int main () {
for ( int i = 0 ; i < N ; i ++ )
for ( int j = 0 ; j < N ; j ++ )
E.elem[i][j] = i == j ;
get_prime () ;
int n ; ll l ;
int ca = 0 ;
while ( scanf ( "%d%I64d" , &n , &l ) != EOF ) {
ac.init () ;
for ( int i = 1 ; i <= n ; i ++ ) {
scanf ( "%s" , s ) ;
ac.insert ( s , i ) ;
}
ac.get_fail () ;
// ac.DEBUG ( l ) ;
int m1 , mm1 , m2 , mm2 , m3 , mm3 , fuck ;//mm为m的乘法逆元
m1 = 173 * 179 , m2 = 163 * 179 , m3 = 163 * 173 ;
extend_gcd ( m1 , 163 , mm1 , fuck ) ;
extend_gcd ( m2 , 173 , mm2 , fuck ) ;
extend_gcd ( m3 , 179 , mm3 , fuck ) ;
int a1 = ac.SOLVE ( 163 , l ) ;
// printf( "a1 = %d\n" , a1 ) ;
int a2 = ac.SOLVE ( 173 , l ) ;
// printf ( "a2 = %d\n" , a2 ) ;
int a3 = ac.SOLVE ( 179 , l ) ;
// printf ( "a3 = %d\n" , a3 ) ;
int ans = ( a1 * m1 * mm1 + a2 * m2 * mm2 + a3 * m3 * mm3 ) % 5047621 ;
printf ( "Case #%d: %d\n" , ++ ca , ans ) ;
}
return 0 ;
}
/*
2 3
AB
BB
2 2
A
B
*/