数位DP+AC自动机似乎是常见套路?
首先题目一看就能想到数位DP。同时题目还涉及到多模式串匹配,于是又要用到AC自动机。
于是,我们可以初步得到状态
d[i][j]
d
[
i
]
[
j
]
,表示当前在第i位,在AC自动机中匹配到第j个节点时的方案数。
然而有以下几个问题:
第一,题目要求的是比n小的满足条件的数的个数,按照经典做法,我们应该加一维表示第i位之前是否与n相同。
第二,是匹配串中有前导0,而数中是不能的,因此状态就得变一变。
我的做法是将状态拆成
d[i][j][0,1,2]
d
[
i
]
[
j
]
[
0
,
1
,
2
]
分别代表正常状态,第i位之前都与n相同,以及第i位之前都为0.
然后三种状态同时转移:
d[i][j][0]=∑d[i−1][c[j][k]][0]
d
[
i
]
[
j
]
[
0
]
=
∑
d
[
i
−
1
]
[
c
[
j
]
[
k
]
]
[
0
]
d[i][j][1]=d[i−1][c[j][n[i]]][1]+∑n[i]−1k=0d[i−1][c[j][k]][0]
d
[
i
]
[
j
]
[
1
]
=
d
[
i
−
1
]
[
c
[
j
]
[
n
[
i
]
]
]
[
1
]
+
∑
k
=
0
n
[
i
]
−
1
d
[
i
−
1
]
[
c
[
j
]
[
k
]
]
[
0
]
d[i][j][2]=d[i−1][j][2]+∑9k=1d[i−1][j][0]
d
[
i
]
[
j
]
[
2
]
=
d
[
i
−
1
]
[
j
]
[
2
]
+
∑
k
=
1
9
d
[
i
−
1
]
[
j
]
[
0
]
其中从
c[j][k]
c
[
j
]
[
k
]
表示第
j
j
个节点的第个儿子
然后我们可以滚动数组压掉一维,然后正常地DP就可以了。
恩,也许有人会问,fail指针去哪里了?
恩,这是一个来自刘汝佳的书的神奇优化,在建AC自动机的bfs中,如果c[u][i]不存在的话(u为队头),那么就把c[u][i]设成c[fail[u]][i],这样在具体匹配中就彻底不需要fail指针了。
为什么我的代码在LOJ和Luogu上都AC了就在BZOJ上WA,呜~~~~
具体实现见代码如下:
#include <bits/stdc++.h>
#define LL long long
#define MOD 1000000007
using namespace std;
char n[1210],temp[1210];
int m,top,len,c[1510][10],f[1510],q[1510],head=1,tail;
bool b[1510];
LL d[2][1510][3],ret;//位数,自动机状态,正常状态,本位之前全与n一致 ,本位之前全为0
void insert()
{
int j=0;
for(int i=0;i<strlen(temp);i++)
{
j=(c[j][temp[i]-'0']?c[j][temp[i]-'0']:c[j][temp[i]-'0']=++top);
}
b[j]=1;
}
void CreateAC()
{
for(int i=0;i<10;i++)
{
if(c[0][i])q[++tail]=c[0][i];
}
while(head<=tail)
{
int u=q[head++];
for(int i=0;i<10;i++)
{
int v=c[u][i];
if(v)
{
q[++tail]=v;
int vv=f[u];
while(vv&&(!c[vv][i]))vv=f[vv];
f[v]=c[vv][i];
}
else
{
c[u][i]=c[f[u]][i];
}
}
}
}
LL dp()
{
for(int i=0;i<=top;i++)
{
if(b[i])continue;
for(int j=0;j<2;j++)
{
d[1][i][j]=1;
}
}
for(register int i=0,cur=0;i<len-1;i++,cur^=1)
{
memset(d[cur],0,sizeof(d[cur]));
for(register int j=0;j<=top;j++)
{
if(b[j])continue;
for(register int k=0;k<(j?2:3);k++)
{
for(register int l=0;l<=(k==1?n[i]-'0':9);l++)
{
if(l==0&&k==2)
{
d[cur][j][2]+=d[cur^1][j][2];
}
else if(l==n[i]-'0'&&k==1)
{
d[cur][j][1]+=d[cur^1][c[j][l]][1];
}
else
{
d[cur][j][k]+=d[cur^1][c[j][l]][0];
}
d[cur][j][k]%=MOD;
}
}
}
}
for(int i=0;i<=n[len-1]-'0';i++)
{
(ret+=d[len&1][i==0?0:c[0][i]][i==0?2:(i==n[len-1]-'0'?1:0)])%=MOD;
}
return ret;
}
int main()
{
scanf("%s\n%d",n,&m);
len=strlen(n);
for(int i=0,j=len-1;i<j;i++,j--)swap(n[i],n[j]);
for(int i=1;i<=m;i++)
{
scanf("%s",temp);
insert();
}
CreateAC();
cout<<dp()<<endl;
}