POJ 1625 Censored!(AC自动机+DP)
http://poj.org/problem?id=1625
题意:
给你由特定N个字符组成的P个模板和长度M,问你由这特定N个字符组成的长为M的文本串不包含任意一个模板有多少种情况?M<=50
分析:
由于M<=50,所以直接用DP做,不用矩阵幂算.本题很类似UVA11468:
http://blog.csdn.net/u013480600/article/details/23294375
我们令d[i][j]=x表示当前在i号节点,还有j步要走且不经过后缀单词节点的情况总数为x.
初值d[i][0]=1. i为非单词节点
d[i][j] = sum(d[k][j-1])其中从i可以走到k,且i和k节点都不是后缀单词节点.
最后我们所求为d[0][m]。(用本解法可以使用记忆话搜索,因为我们很容易知道d[i][j]的所有依赖项d[k][j-1])
或者可以这么推(程序中用的就是该方式):
令d[i][j]=x表示当前在i点,已经走过了j距离的情况总数为x.
d[i][j]=sum(d[k][j-1])从k可以走到i,且k和i都是非单词节点.
初值为d[0][0]=1.其他都为0,然后用滚动数组递推即可。(用本节点只能使用递推刷新求DP,因为我们无法知道d[i][j]的所有依赖项d[k][j-1])
AC代码:
#include<iostream>
#include<queue>
#include<map>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
const int maxnode=100+20;
int sigma_size;//后面读入
struct AC_Automata
{
int ch[maxnode][50+20];
int match[maxnode];
int f[maxnode];
map<char ,int> mp;
int sz;
void init()
{
sz=1;
memset(ch[0],0,sizeof(ch[0]));
match[0]=f[0]=0;
mp.clear();
}
void insert(char *s)
{
int n=strlen(s),u=0;
for(int i=0;i<n;i++)
{
int id=mp[s[i]];
if(ch[u][id]==0)
{
ch[u][id]=sz;
memset(ch[sz],0,sizeof(ch[sz]));
match[sz++]=0;
}
u=ch[u][id];
}
match[u]=1;
}
void getFail()
{
f[0]=0;
queue<int> q;
for(int i=0;i<sigma_size;i++)
{
int u=ch[0][i];
if(u)
{
f[u]=0;
q.push(u);
}
}
while(!q.empty())
{
int r=q.front();q.pop();
for(int i=0;i<sigma_size;i++)
{
int u=ch[r][i];
if(!u){ ch[r][i]=ch[f[r]][i]; continue; }
q.push(u);
int v=f[r];
while(v && ch[v][i]==0) v=f[v];
f[u]=ch[v][i];
match[u] |= match[f[u]];
}
}
}
};
AC_Automata ac;
int dp[100+10][50+10][100];//dp[i][j]是一个最大长度为100的大数
void add(int *a,int *b)//大数相加,a[0]表示的是个位,a[1]是十位
{
int c=0;//进位
for(int i=0;i<100;i++)
{
int s=a[i]+b[i]+c;
a[i]=s%10;
c=s/10;
}
}
int main()
{
int n,m,p;
char str[50+20];
while(scanf("%d%d%d",&n,&m,&p)==3)
{
sigma_size=n;
ac.init();
scanf("%s",str);
for(int i=0;i<n;i++)
ac.mp[str[i]]=i;
while(p--)
{
scanf("%s",str);
ac.insert(str);
}
ac.getFail();
memset(dp,0,sizeof(dp));
dp[0][0][0]=1;//起点
for(int k=0;k<m;k++)
{
for(int i=0;i<ac.sz;i++)
if(ac.match[i]==0)
for(int j=0;j<sigma_size;j++)
if(ac.match[ac.ch[i][j]]==0)
add(dp[ac.ch[i][j]][k+1],dp[i][k]);
}
int ans[100];
memset(ans,0,sizeof(ans));//记得初始化
for(int i=0;i<ac.sz;i++)
if(ac.match[i]==0)
add(ans,dp[i][m]);
int i=99;
for(;i>=1;i--)
if(ans[i])
break;
for(;i>=0;i--)
printf("%d",ans[i]);
putchar('\n');
}
return 0;
}