ac自动机+dp+高精度
ac自动机的节点作为状态,dp[i][j]表示长度为i状态为j的种类数。转移时注意已经是串的节点不能转移并且不能被转移即可。
需要注意的一点是输入的字符的ascll码有负数。
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
using namespace std;
const int N=500;
char use[501];
int dp[55][555][55],ans[55];
struct
{
int next[501],tmp,fail;
}trie[555];
struct matrix
{
int data[511][511],size;
matrix(int tmp,int lon)
{
size=lon;
for(int i=0;i<=size;i++)
for(int j=0;j<=size;j++)
data[i][j]=tmp;
}
};
int lon;
void trieini()
{
memset(trie,0,sizeof(trie));
lon=0;
}
void insert(char s[])
{
int t=0;
int n=strlen(s+1);
for(int i=1;i<=n;i++)
{
if(trie[t].next[s[i]+200]==0)
trie[t].next[s[i]+200]=++lon;
t=trie[t].next[s[i]+200];
if(i==n)
trie[t].tmp++;
}
}
void getfail()
{
int root=0;
trie[root].fail=root;
queue <int> q;
for(int i=0;i<=N;i++)
{
if(trie[root].next[i])
{
int u=trie[root].next[i];
trie[u].fail=root;
q.push(u);
}
}
while(!q.empty())
{
int t=q.front();
q.pop();
for(int i=0;i<=N;i++)
if(trie[t].next[i])
{
int u=trie[t].next[i];
int tmp=trie[t].fail;
while(tmp!=root&&trie[tmp].next[i]==0)
tmp=trie[tmp].fail;
if(trie[tmp].next[i])
trie[u].fail=trie[tmp].next[i];
else
trie[u].fail=root;
q.push(u);
}
}
for(int i=1;i<=lon;i++)
{
int u=trie[i].fail;
while(u!=root)
{
trie[i].tmp+=trie[u].tmp;
u=trie[u].fail;
}
}
}
void find(matrix &a)
{
int root=0;
for(int k=0;k<=lon;k++)
{
if(trie[k].tmp) continue;
for(int i=0;i<=N;i++)
{
if(!use[i])
continue;
int t=k;
while(t!=root&&trie[t].next[i]==0)
t=trie[t].fail;
if(trie[t].next[i])
{
int u=trie[t].next[i];
if(!trie[u].tmp)
a.data[k][u]++;
}
else
a.data[k][0]++;
}
}
}
void cal(int a[],int b[])
{
for(int i=1;i<=50;i++)
{
a[i]+=b[i];
a[i+1]+=a[i]/10000;
a[i]%=10000;
}
}
void prin(int a[])
{
int k=50;
while(a[k]==0&&k>1) k--;
printf("%d",a[k]);
for(int i=k-1;i>=1;i--)
printf("%04d",a[i]);
printf("\n");
}
int main()
{
int n,m,p;
while(scanf("%d %d %d",&n,&m,&p)!=EOF)
{
memset(use,0,sizeof(use));
trieini();
char a[101];
scanf("%s",a+1);
for(int i=1;i<=n;i++)
{
use[a[i]+200]=1;
}
for(int i=1;i<=p;i++)
{
scanf("%s",a+1);
insert(a);
}
getfail();
matrix aa(0,lon);
find(aa);
memset(ans,0,sizeof(ans));
memset(dp,0,sizeof(dp));
dp[0][0][1]=1;
for(int i=0;i<m;i++)
{
for(int k=0;k<=lon;k++)
{
for(int j=0;j<=lon;j++)
for(int p=1;p<=aa.data[j][k];p++)
cal(dp[i+1][k],dp[i][j]);
}
}
for(int i=0;i<=lon;i++)
cal(ans,dp[m][i]);
prin(ans);
}
return 0;
}