Codeforces 86C Genetic engineering (AC自动机好题)
题意:给出一个字符串集合,总共有m个字符串,每个字符串长度不超过10。然后给出一个n,构造长度为n的串,这个串上的每一个字符,往前,往后延伸构成的若干个字符串中,至少有一个包含在字符集里的某一个字符串里面。问有多少种构造方法。。
解题思路:AC自动机。。这题还是比较难想的啊。首先我们可以定下两维状态,dp[i][j]表示构造长度为i的串,走到了j号节点。但是这样的状态显然不是最优的,因为我可以当前这个字符没匹配上,但是下一个在加入下一个字符的时候,构成的字符串把当前这个字符给匹配进去了。于是,可以加一维状态,可以理解为给未被匹配上的的字符预留一些长度。因此我定义的状态就是dp[i][j][k]表示长度为i,走到了j号节点,有k个字符还未被匹配上。状态定义好了,接下来就是怎么转移了。在自动机上,我记录了一个信息,val[i]表示如果走i这个节点,能匹配的最长的字符串长度是多少。那么转移方程就写成
for ( p = 0 ; p < 4 ; p ++ ) {int t = c[p][j] ;
if ( val[t] >= k + 1 )
dp[i+1][t][0] = ( dp[i+1][t][0] + dp[i][j][k] ) % mod ;
else dp[i+1][t][k+1] = ( dp[i+1][t][k+1] + dp[i][j][k] ) % mod ;
}
最后统计下和就好了。
#include<stdio.h>
#include<string.h>
#include<math.h>
#include<queue>
#include<algorithm>
using namespace std ;
const int maxn = 333333 ;
int c[maxn*10] ;
const int mod = 1000000009 ;
int dp[1111][111][15] ;
char s[15] ;
int l[15] ;
int get ( char c ) {
if ( c == 'A' ) return 0 ;
if ( c == 'C' ) return 1 ;
if ( c == 'G' ) return 2 ;
return 3 ;
}
struct ac_auto {
queue<int> Q ;
int tot , c[4][111] ;
int val[111] , fail[111] ;
int new_node () {
int i ;
for ( i = 0 ; i < 4 ; i ++ ) c[i][tot] = 0 ;
fail[tot] = val[tot] = 0 ;
return tot ++ ;
}
void init () {
tot = 0 ;
new_node () ;
}
void insert ( char *s , int id ) {
int now = 0 ;
for ( ; *s ; s ++ ) {
int k = get ( *s ) ;
if ( !c[k][now] ) c[k][now] = new_node () ;
now = c[k][now] ;
}
val[now] = max ( val[now] , l[id] ) ;
}
void get_fail () {
int i , u = 0 ;
for ( i = 0 ; i < 4 ; i ++ )
if ( c[i][u] )
Q.push ( c[i][u] ) ;
while ( !Q.empty () ) {
u = Q.front () , Q.pop () ;
for ( i = 0 ; i < 4 ; i ++ ) {
if ( c[i][u] ) {
int e = c[i][u] ;
int j = fail[u] ;
fail[e] = c[i][j] ;
val[e] = max ( val[e] , val[fail[e]] );
// if ( e == 1 ) puts ( "fuck" ) ;
Q.push ( e ) ;
}
else c[i][u] = c[i][fail[u]] ;
// if ( u == 1 && i == 0 ) printf ( "fuck1 = %d\n" , c[0][1] ) ;
}
}
}
void solve ( int n ) {
int i , j , k , p ;
// printf ( "fuck %d\n" , c[0][1] ) ;
for ( i = 0 ; i < n ; i ++ ) {
for ( j = 0 ; j < tot ; j ++ )
for ( k = 0 ; k <= 10 ; k ++ ) {
if ( dp[i][j][k] ) {
// printf ( "dp[%d][%d][%d] = %d\n" , i , j , k , dp[i][j][k] ) ;
for ( p = 0 ; p < 4 ; p ++ ) {
int t = c[p][j] ;
if ( val[t] >= k + 1 )
dp[i+1][t][0] = ( dp[i+1][t][0] + dp[i][j][k] ) % mod ;
else dp[i+1][t][k+1] = ( dp[i+1][t][k+1] + dp[i][j][k] ) % mod ;
}
}
}
}
int ans = 0 ;
for ( i = 0 ; i < tot ; i ++ )
ans = ( ans + dp[n][i][0] ) % mod ;
printf ( "%d\n" , ans ) ;
}
} ac ;
int main () {
int n , m , i , j , k , p , r , t ;
while ( scanf ( "%d%d" , &n , &m ) != EOF ) {
ac.init () ;
j = 0 ;
while ( m -- ) {
scanf ( "%s" , s ) ;
l[++j] = strlen ( s ) ;
ac.insert ( s , j ) ;
}
ac.get_fail () ;
memset ( dp , 0 , sizeof ( dp ) ) ;
dp[0][0][0] = 1 ;
ac.solve ( n ) ;
}
}