【BZOJ4861】[Beijing2017]魔法咒语
题意:别看BZ的题面了,去看LOJ的题面吧~
题解:显然,数据范围明显的分成了两部分:一个是L很小,每个基本词汇长度未知;一个是L很大,每个基本词汇的长度是1或2。看来只能写两份代码了。
对于L很小的,我们先将禁忌串建成一个AC自动机,然后预处理出to[i][j]表示AC自动机中的第i个节点在加入基本词汇j后会到达的节点。然后设f[i][j]表示总长度为i,匹配到第j个节点的方案数。然后DP一下就好了。
对于L很大的,我们想到矩乘,设ans[i][j]表示总长度为i,匹配到第j个节点的方案数。但是ans[i]这个矩阵由ans[i-1]和ans[i-2]两个矩阵转移过来,所以我们直接用分块矩阵的乘法,即:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
using namespace std;
typedef long long ll;
const ll mod=1000000007;
int n,m,N,M,L,mx,sum;
struct mat
{
ll v[210][210];
mat (){memset(v,0,sizeof(v));}
ll* operator [](int a){return v[a];}
mat operator * (mat a)
{
mat ret;
int i,j,k;
for(i=1;i<=2*M;i++) for(j=1;j<=2*M;j++) for(k=1;k<=2*M;k++) (ret[i][j]+=v[i][k]*a[k][j])%=mod;
return ret;
}
}ans,x;
int l1[60],to[110][60];
ll f[110][110];
queue<int> q;
struct node
{
int ch[26],fail,cnt;
}p[110];
char s1[60][110],s2[60][110];
void build()
{
q.push(1);
int i,j,k,a,u;
while(!q.empty())
{
u=q.front(),q.pop();
for(i=0;i<26;i++)
{
if(!p[u].ch[i])
{
if(u==1) p[u].ch[i]=1;
else p[u].ch[i]=p[p[u].fail].ch[i];
continue;
}
q.push(p[u].ch[i]);
if(u==1)
{
p[p[u].ch[i]].fail=1;
continue;
}
p[p[u].ch[i]].fail=p[p[u].fail].ch[i];
p[p[u].ch[i]].cnt|=p[p[p[u].fail].ch[i]].cnt;
}
}
for(i=1;i<=M;i++) for(j=1;j<=n;j++)
{
u=i,a=strlen(s1[j]);
if(p[u].cnt) to[i][j]=-1;
for(k=0;k<a;k++)
{
u=p[u].ch[s1[j][k]-'a'];
if(p[u].cnt) break;
}
if(k==a) to[i][j]=u;
else to[i][j]=-1;
}
}
void DP()
{
int i,j,k,a;
f[0][1]=1;
for(i=0;i<L;i++) for(j=1;j<=M;j++) for(k=1;k<=n;k++)
{
if(to[j][k]==-1) continue;
a=strlen(s1[k]);
if(a+i<=L) (f[a+i][to[j][k]]+=f[i][j])%=mod;
}
for(i=1;i<=M;i++) sum=(sum+f[L][i])%mod;
printf("%d",sum);
}
void pm(int y)
{
while(y)
{
if(y&1) ans=ans*x;
x=x*x,y>>=1;
}
}
void MM()
{
int i,j;
for(i=1;i<=M;i++)
{
for(j=1;j<=n;j++)
{
if(to[i][j]==-1) continue;
if(strlen(s1[j])==1) x[i][to[i][j]]++;
else x[i+M][to[i][j]]++;
}
x[i][i+M]++;
}
ans[1][1]=1;
pm(L);
for(i=1;i<=M;i++) sum=(sum+ans[1][i])%mod;
printf("%d",sum);
}
int main()
{
scanf("%d%d%d",&n,&m,&L);
int i,j,a,b,u;
N=1,M=1;
for(i=1;i<=n;i++) scanf("%s",s1[i]),a=strlen(s1[i]),mx=max(mx,a);
for(i=1;i<=m;i++)
{
scanf("%s",s2[i]),a=strlen(s2[i]);
for(u=1,j=0;j<a;j++)
{
b=s2[i][j]-'a';
if(!p[u].ch[b]) p[u].ch[b]=++M;
u=p[u].ch[b];
}
p[u].cnt=1;
}
build();
if(mx<=2) MM();
else DP();
return 0;
}