题意:给出一个字符集V和P个模式串(长度小于10),问由这个字符集中字符组成的长度为N的且不包含任意一个模式串的字符串有多少个?(字符集大小,N<=50, P <= 10) 。
思路:先将P个模式串建立AC自动机,标记好危险节点(flag数组)。然后动归来求:dp[i][j]表示长度为i且最后在节点j的字符串个数(节点j必为安全节点),初始dp[0][1] = 1, 其他dp[i][j] = 0。由dp[i][j] 可以导出,每个由j可以到达的安全节点son[j],执行:dp[i+1][son[j]] += dp[i][j]。因为从根走i步到达节点j有n种走法,那么走i+1步到达son[j]的走法就要加n。最终的答案为∑{dp[N][j] | j是安全节点}。
最后的数量很大,需要用数组存数。
需要注意:如果是用指针来存储,那么son[j]不仅指从j通过一条字母边直接到达的son[j], 也可以是通过若干前缀指针后再通过一个字母边到达son[j],(即son[j]并不一定是 j 的子节点)。而用数组存储恰恰避免了这一点。
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#include <queue>
#include <cstdlib>
using namespace std;
#define INF 0x3fffffff
#define N 505
#define M 55
int n,m,p;
int t[N][M],fail[N],top;
bool flag[N];
char word[M],s[M];
struct dp{
int num[100];
bool has;
}dp[M][N];
int res[100];
map<char, int> hh;
queue<int> q;
void init(){
int i;
memset(t, -1, sizeof(t));
memset(fail, 0, sizeof(fail));
memset(flag, false, sizeof(flag));
for(i = 0;i<n;i++)
t[0][i] = 1;
top = 1;
hh.clear();
for(i = 0;word[i];i++)//字母表和0...n-1的对应
hh[word[i]] = i;
}
void insert(char* s){
int i,r = 1;
for(i = 0;s[i];i++){
if(t[r][hh[s[i]]] == -1)
t[r][hh[s[i]]] = (++top);
r = t[r][hh[s[i]]];
}
flag[r] = true;
}
void buildDFA(){
int i,now;
q.push(1);
while(!q.empty()){
now = q.front();
q.pop();
for(i = 0;i<n;i++){
if(t[now][i] == -1)
t[now][i] = t[fail[now]][i];
else{
fail[t[now][i]] = t[fail[now]][i];
q.push(t[now][i]);
if(flag[t[fail[now]][i]])//危险节点建立好
flag[t[now][i]] = true;
}
}
}
}
void add(int* a,int* b){//大数相加,把b加到a
int i,j;
for(i = j = 0;i<100;i++){
a[i] += b[i]+j;
j = a[i]/10;
a[i] %= 10;
}
}
int main(){
int i,j,k;
while(scanf("%d %d %d\n",&n,&m,&p)!=EOF){
gets(word);
init();
for(i = 1;i<=p;i++){
gets(s);
insert(s);
}
buildDFA();
for(i = 0;i<=m;i++)
for(j = 1;j<=top;j++){
memset(dp[i][j].num, 0, sizeof(dp[i][j].num));
dp[i][j].has = false;
}
dp[0][1].num[0] = 1;
dp[0][1].has = true;
for(i = 0;i<m;i++)
for(j = 1;j<=top;j++)
for(k = 0;k<n;k++)
if(dp[i][j].has && !flag[t[j][k]]){
add(dp[i+1][t[j][k]].num , dp[i][j].num);
dp[i+1][t[j][k]].has = true;
}
memset(res, 0, sizeof(res));
for(i = 1;i<=top;i++)
add(res,dp[m][i].num);
for(i = 99;i>=0&&!res[i];i--);
if(i==-1)
putchar('0');
for(;i>=0;i--)
printf("%d",res[i]);
putchar('\n');
}
return 0;
}