HDU 2825 Wireless Password(AC自动机+状态压缩DP)
http://acm.hdu.edu.cn/showproblem.php?pid=2825
题意:
现在要你推断一个长度==n的由小写字母构成的字符串S有多少种组成方式.其中这个S至少包含字典集合中的k个单词.字典集合中有m个单词并已给出.
分析:
该题要用刘汝佳:训练指南上带match的AC自动机。每个单词有一个编号,match[j]表示j节点表示的单词状态,比如j节点表示单词5和单词3,那么match[j]== 1<<5 | 1<<3。即该match[j]值的二进制形式正好有2位为1,是第3位和第5位。
令d[i][j][k]=x表示当前走了i步,处于AC自动机的j号节点,且字符串中的单词出现情况为k(k的二进制形式表示集合)时的情况种数为x。
d[i+1][j][k|match[j]] += d[i][j1][k]
从j1号节点能走到j号节点,且j号节点的后缀单词覆盖情况为match[j],那么从j1节点走到j节点后,总的单词覆盖情况就是 k|match[j] 了(想想是不是)。
初值为:d[0][0][0]=1,其他所有d都为0。
ans = d[n][i][k] 其中i为0到sz-1的任意节点,而集合k的二进制形式至少要包含K个1。
AC代码:
#include<cstdio>
#include<queue>
#include<cstring>
using namespace std;
const int MOD=20090717;
const int maxnode=100+10;
const int sigma_size=26;
int dp[30][maxnode][(1<<10)+100];//开始这么定义dp[][][1<<10+100],结果内存溢出出现了BUG错误
int num[(1<<10)+10];//num[i]=x表示i的二进制有x个1
int N,M,K;
struct AC_Automata
{
int ch[maxnode][sigma_size];
int f[maxnode];
int match[maxnode];//此处的match是一个2进制位集合,不再是0或1了
int sz;
void init()
{
sz=1;
memset(ch[0],0,sizeof(ch[0]));
f[0]=match[0]=0;
}
void insert(char *s,int v)
{
int n=strlen(s),u=0;
for(int i=0;i<n;i++)
{
int id=s[i]-'a';
if(ch[u][id]==0)
{
ch[u][id]=sz;
memset(ch[sz],0,sizeof(ch[sz]));
match[sz++]=0;
}
u=ch[u][id];
}
match[u] |=1<<v;
}
void getFail()
{
f[0]=0;
queue<int> q;
for(int i=0;i<sigma_size;i++)
{
int u=ch[0][i];
if(u)
{
f[u]=0;
q.push(u);
}
}
while(!q.empty())
{
int r=q.front();q.pop();
for(int i=0;i<sigma_size;i++)
{
int u=ch[r][i];
if(!u) { ch[r][i]=ch[f[r]][i]; continue; }
q.push(u);
int v=f[r];
while(v && ch[v][i]==0) v=f[v];
f[u] = ch[v][i];
match[u] |= match[f[u]];
}
}
}
int solve()
{
for(int i=0;i<=N;i++)
for(int j=0;j<sz;j++)
for(int st=0;st<(1<<M);st++)
dp[i][j][st]=0;
dp[0][0][0]=1;
for(int i=0;i<N;i++)//当前行走的长度i
{
for(int j=0;j<sz;j++)//当前所在的j号节点
{
for(int st=0;st<(1<<M);st++)if(dp[i][j][st]>0)//当前状态st
{
for(int k=0;k<sigma_size;k++)
{
dp[i+1][ch[j][k]][(st|match[ch[j][k]])] =(dp[i+1][ch[j][k]][(st|match[ch[j][k]])] + dp[i][j][st])%MOD ;
}
}
}
}
int ans=0;
for(int st=0;st<(1<<M);st++)if(num[st]>=K)
for(int i=0;i<sz;i++)
ans = (ans +dp[N][i][st])%MOD;
return ans;
}
}ac;
int main()
{
memset(num,0,sizeof(num));
for(int st=1;st<(1<<10);st++)
{
for(int j=0;j<10;j++)
if(st&(1<<j)) num[st]++;
}
while(scanf("%d%d%d",&N,&M,&K)==3&&N)
{
if(N==0&&M==0&&K==0) break;
ac.init();
for(int i=0;i<M;i++)
{
char str[100];
scanf("%s",str);
ac.insert(str,i);
}
ac.getFail();
printf("%d\n",ac.solve());
}
return 0;
}