编者:第一次做这种题,以前做的都是AC自动机模板题,突然发现简单题都做过了,只能去做关于DP或矩阵的了,稍微错一点就超时好恶心呀,特别是把dp的第三维写成1024就超时了,无语。。。
题意:给你长度为十的十个单词,问你长度为字串中至少包含k个单词(单词间可以重叠)个数有多少。
解法:十个单词可以用十位0或1来表示使用与否 例如 0000000001 表示只用了单词1 这样状态一共有 1024个,这个就是状态压缩。把十个单词做成ac自动机后,就可以dp了。ac自动机的val数组存的就是每个单词自己的编号状态了 例如 单词 5-> 0000010000。 dp[i][j][s]中i表示当前串的长度,j表示所在节点的编号(空的都是0),s表示包含单词的状态,而c表示字母列表 。状态转移方程: dp[i][ac.ch[j][c]][s|tnum[ac.ch[j][c]]] += dp[i][j][s]。最后累加起来就好了。
#include <iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<map>
#include<queue>
#include<cmath>
#include<vector>
#define inf 0x3f3f3f3f
#define Inf 0x3FFFFFFFFFFFFFFFLL
#define pi acos(-1.0)
#define eps 1e-8
using namespace std;
const int mod = 20090717;
const int maxnode = 102;
const int char_size = 26;
struct autoAC
{
int ch[maxnode][char_size], val[maxnode], f[maxnode], last[maxnode];
int sz;
int idx(char c) {return c-'a';}
void init() {memset(ch[0], 0,sizeof ch[0]);sz=1;}
void insert(char* s, int v = 1)
{
int u = 0, n = strlen(s);
for(int i = 0 ; i < n ; ++ i)
{
int c = idx(s[i]);
if(!ch[u][c])
{
memset(ch[sz],0,sizeof ch[sz]);
val[sz]=0;
ch[u][c] = sz++;
}
u = ch[u][c];
}
val[u] |= (1<<v);
}
void getFail()
{
queue<int> q;
f[0] = 0;
for(int c = 0 ; c < char_size ; ++ c)
{
int u = ch[0][c];
if(u) {f[u]=last[u]=0; q.push(u);}
}
while(!q.empty())
{
int r = q.front(); q.pop();
for(int c = 0 ; c < char_size ; ++ c)
{
int u = ch[r][c];
if(!u) {ch[r][c]=ch[f[r]][c]; continue;}
q.push(u);
int v = f[r];
while(v&&!ch[v][c]) v = f[v];
f[u] = ch[v][c];
last[u] = val[f[u]]?f[u]:last[f[u]];
}
}
}
}ac;
int dp[26][maxnode][1024+1];
int main()
{
//freopen("in.txt","r",stdin);
int num[1024];//表示一个数字二进制形式有多少个1
for(int i = 0 ; i < 1024 ; ++ i)
{
int j = i, res = 0;
while(j)
{
res += (1&j);
j>>=1;
}
num[i] = res;
}
int n, m, k;
while(~scanf("%d%d%d",&n,&m,&k)&&(n||m||k))
{
char s[maxnode];
ac.init();
for(int i = 0 ; i < m ; ++ i)
{
scanf("%s",s);
ac.insert(s,i);
}
ac.getFail();
int tnum[maxnode];
memset(tnum,0,sizeof tnum);
for(int i = 0 ; i < ac.sz ; ++ i)
{
int j = i;
while(j)
{
tnum[i]|=ac.val[j];
j = ac.last[j];
}
}
memset(dp, 0, sizeof dp);
dp[0][0][0] = 1;
for(int i = 0 ; i < n ; ++ i)
for(int j = 0 ; j < ac.sz ; ++ j)
for(int s = 0 ; s < (1<<m) ; ++ s)
{
if(dp[i][j][s])
{
for(int c = 0 ; c < char_size ; ++ c)
{
int u = ac.ch[j][c], sb = s|tnum[u];
dp[i+1][u][sb] += dp[i][j][s];
if(dp[i+1][u][sb]>=mod) dp[i+1][u][sb]-=mod;
}
}
}
int ans = 0;
for(int i = 0 ; i < ac.sz ; ++ i)
{
for(int j = 0 ; j < (1<<m) ; ++ j)
{
if(num[j]>=k) ans += dp[n][i][j];
ans%=mod;
}
}
printf("%d\n",ans);
}
return 0;
}