题意:
求有多少个长度为L,且包含n个串的字符串,答案小于等于42时输出方案。(n<=10,L<=25,|S|<=10)
思路:
常规的AC自动机上的壮压DP,通过Fail去掉那些被包含的字符串即可。
d
p
[
i
]
[
j
]
[
s
]
:
dp[i][j][s]:
dp[i][j][s]:走了
i
i
i步,当前在点
j
j
j,已经走过的串的二进制为
s
s
s的方案数。
d
p
[
i
+
1
]
[
c
h
[
j
]
[
c
]
]
[
s
∣
e
d
[
c
h
[
j
]
[
c
]
]
]
+
=
d
p
[
i
]
[
j
]
[
s
]
dp[i+1][ch[j][c]][s|ed[ch[j][c]]]+=dp[i][j][s]
dp[i+1][ch[j][c]][s∣ed[ch[j][c]]]+=dp[i][j][s]
a
n
s
+
=
d
p
[
L
]
[
i
]
[
S
(
全
集
)
]
ans+=dp[L][i][S(全集)]
ans+=dp[L][i][S(全集)]
由于答案小于等于42才输出方案,这个时候所有n个字符串都是紧密相连的,可以从终止状态开始搜索,也可以n!枚举每个字符串的相对位置,由于已经去掉了包含关系,所以相邻的字符串之间尽量多的重叠即可。
#include<cstdio>
#include<queue>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 105
#define MAXM 1024
#define LL long long
LL dp[30][MAXN][MAXM],sum;
int ch[MAXN][26],tot,cnt,S,fail[MAXN],ed[MAXN];
char s[15],mem[MAXN],a[30];
vector<string>ans;
void Insert()
{
int nw=0,len=strlen(s);
for(int i=0;i<len;i++)
{
int c=s[i]-'a';
if(ch[nw][c]==0)
{
ch[nw][c]=++tot;
mem[tot]=s[i];
}
nw=ch[nw][c];
}
ed[nw]=1;
}
queue<int>Q;
void Build()
{
for(int i=0;i<26;i++) if(ch[0][i]) Q.push(ch[0][i]);
while(!Q.empty())
{
int u=Q.front();Q.pop();
for(int i=0;i<26;i++)
if(!ch[u][i]) ch[u][i]=ch[fail[u]][i];
else
{
int v=ch[u][i];
fail[v]=ch[fail[u]][i];
ed[v]|=ed[fail[v]];
Q.push(v);
}
}
for(int i=0;i<=tot;i++) ed[fail[i]]=0;
for(int i=0;i<=tot;i++) if(ed[i]) ed[i]=1<<cnt,cnt++;
}
void dfs(int len,int pos,int sta)
{
a[len-1]=mem[pos];
if(len==1)
{
ans.push_back(a);
return;
}
int c=mem[pos]-'a';
for(int i=0;i<=tot;i++)
if(dp[len-1][i][sta]&&ch[i][c]==pos) dfs(len-1,i,sta);
if(ed[pos])
{
int s=sta^ed[pos];
for(int i=0;i<=tot;i++)
if(dp[len-1][i][s]&&ch[i][c]==pos) dfs(len-1,i,s);
}
}
int main()
{
int L,N;
scanf("%d%d",&L,&N);
for(int i=0;i<N;i++)
{
scanf("%s",s);
Insert();
}
Build();
S=1<<cnt;cnt=0;
dp[0][0][0]=1;
for(int i=0;i<L;i++)
for(int j=0;j<=tot;j++)
for(int s=0;s<S;s++)
if(dp[i][j][s])
{
for(int k=0;k<26;k++)
dp[i+1][ch[j][k]][s|ed[ch[j][k]]]+=dp[i][j][s];
}
for(int i=0;i<=tot;i++)
sum+=dp[L][i][S-1];
printf("%lld\n",sum);
if(sum<=42)
{
for(int i=0;i<=tot;i++)
if(dp[L][i][S-1])
dfs(L,i,S-1);
sort(ans.begin(),ans.end());
for(int i=0;i<ans.size();i++)
cout<<ans[i]<<'\n';
}
}