【JZOJ5332】密码
(File IO): input:word.in output:word.out
Time Limits: 1000 ms Memory Limits: 524288 KB
Description
Input
Output
Sample Input
3 1
2 20
2
4
9
Sample Output
6
Data Constraint
Hint
解题思路
我们假设题目变成:给出了一个长数,那么你要逐位求出到这位为止的包含多少个秘钥。
如果可以维护这个,那么原题就可以差分一下用数位dp做:设
f[l][i][0/1][...]
表示,到i为为止,有l个秘钥,数字是否卡上界,并且……的方案数。
但是怎么维护?这是个多串匹配问题——AC自动机!
Aho-Corasick Automation AC自动机
AC自动机一般用来解决多串匹配问题,它是基于trie的能处理字符串匹配的算法
首先,对所有要匹配的字符串建一棵trie树
建trie树
void insert(int *a){
for(int x=0,i=1;i<=a[0];i++){
if(!nxt[x][a[i]])nxt[x][a[i]]=++tot,fa[tot]=x,deep[tot]=deep[x]+1,c[tot]=a[i];
x=nxt[x][a[i]];
}
}
那么,当询问一个被匹配的字符串s时,从s的每一位开始往树上跑,然后看一下能匹配到哪里,正确的但是这种方法会被卡成 O(len⋅deep)
那么我们想这种方法既然是处理字符串问题的,那么能不能像kmp一样处理失配的情况呢?
fail指针
当匹配失配时,我们应找一个位置继续匹配,和KMP类似,我们发现fail指针可以由father的fail推来:
假设当前位节点x,颜色为c,拿一个指针p,一开始
p=fail[father[x]]
如果p没有c子节点,那么p跳到它自己的fail,直到找到或跳到了根为止
特别的,当
father[x]=root
时
fail[x]=root
void getfail(){
int x,col,i,y,z;
for(head=0,que[tail=1]=0;head!=tail;){
for(x=que[++head],i=0;i<10;i++)if((y=nxt[x][i])){
if(deep[y]==1){fail[y]=0;que[++tail]=y;continue;}
for(z=fail[x],col=c[y];z && !nxt[z][col];z=fail[z]);
fail[y]=nxt[z][col];que[++tail]=y;
}
}
}
求值
我们需要求某个以位置i结尾的能被多少个字符串匹配
很简单,当放入一个串时,把串尾打上个+1标记
求fail时将标记和fail合并,于是就成了这样:
把板子放一下:
struct trie{
int c[300],deep[300],nxt[300][10],fa[300],tot,t[300],fail[300];
bool bz[300];
void insert(int *a){
for(int x=0,i=1;i<=a[0];i++){
if(!nxt[x][a[i]])nxt[x][a[i]]=++tot,fa[tot]=x,deep[tot]=deep[x]+1,c[tot]=a[i];
x=nxt[x][a[i]];if(i==a[0])t[x]++;
}
}
void getfail(){
int x,col,i,y,z;
for(head=0,que[tail=1]=0;head!=tail;){
for(x=que[++head],i=0;i<10;i++)if((y=nxt[x][i])){
if(deep[y]==1){fail[y]=0;que[++tail]=y;continue;}
for(z=fail[x],col=c[y];z && !nxt[z][col];z=fail[z]);
fail[y]=nxt[z][col];que[++tail]=y;
}
}
}
void count(){
bz[0]=1;
int x,col,i,y,z;
for(int i=1;i<=tot;i++)if(!bz[i]){
for(head=0,bz[que[tail=1]=i]=1;head!=tail;)if(!bz[y=fail[que[++head]]])bz[que[++tail]=y]=1;
for(i=tail;i;i--)t[que[i]]+=t[fail[que[i]]];
}
}
};
正解
回到原题
设数位DP
f[l][i][j][0/1]
表示完成了l个秘钥,匹配到i,trie跑到了j,是否卡上界,的方案数,转移时枚举下一个数x,往fail链跑一跑,找到一个p,p有x子节点,但这样会超时
其实预处理一下
nex[i][j]
表示trie的i的位置,要找下一个数j应该走到哪儿。
So easy!
#include<cstring>
#include<cstdio>
#define mo 1000000007
#define min(a,b) ((a)<(b)?(a):(b))
using namespace std;
int n,k,lim[1001],a[200][600],que[300],head,tail;
long long f[11][600][300][2],ans;
char x[1001],y[1001],tmp[1001];
struct trie{
int c[300],deep[300],nxt[300][10],fa[300],tot,t[300],fail[300],nex[300][10];
bool bz[300];
void insert(int *a){
for(int x=0,i=1;i<=a[0];i++){
if(!nxt[x][a[i]])nxt[x][a[i]]=++tot,fa[tot]=x,deep[tot]=deep[x]+1,c[tot]=a[i];
x=nxt[x][a[i]];if(i==a[0])t[x]++;
}
}
void getfail(){
int x,col,i,y,z;
for(head=0,que[tail=1]=0;head!=tail;){
for(x=que[++head],i=0;i<10;i++)if((y=nxt[x][i])){
if(deep[y]==1){fail[y]=0;que[++tail]=y;continue;}
for(z=fail[x],col=c[y];z && !nxt[z][col];z=fail[z]);
fail[y]=nxt[z][col];que[++tail]=y;
}
for(int i=0;i<10;i++){
for(y=x;y && !nxt[y][i];y=fail[y]);
nex[x][i]=nxt[y][i];
}
}
}
void count(){
bz[0]=1;
int x,col,i,y,z;
for(int i=1;i<=tot;i++)if(!bz[i]){
for(head=0,bz[que[tail=1]=i]=1;head!=tail;)if(!bz[y=fail[que[++head]]])bz[que[++tail]=y]=1;
for(i=tail;i;i--)t[que[i]]+=t[fail[que[i]]];
}
}
}tr;
long long ask(){
memset(f,0,sizeof(f));
f[0][0][0][1]=1;int a;
for(int l=0;l<=k;l++)
for(int i=0;i<lim[0];i++)
for(int j=0;j<=tr.tot;j++){
for(int x=0;x<=9;x++){
int p=tr.nex[j][x],a=min(l+tr.t[p],k);
f[a][i+1][p][0]=(f[a][i+1][p][0]+f[l][i][j][0]+(x<lim[i+1]?f[l][i][j][1]:0))%mo;
if(x==lim[i+1])f[a][i+1][p][1]=(f[a][i+1][p][1]+f[l][i][j][1])%mo;
}
}
long long ans=0;
for(int i=0;i<=tr.tot;i++)ans=(ans+f[k][lim[0]][i][0]+f[k][lim[0]][i][1])%mo;
return (ans+mo)%mo;
}
int main(){
freopen("word.in","r",stdin);
freopen("word.out","w",stdout);
scanf("%d %d\n",&n,&k);
scanf("%s",x+1);scanf("%s",y+1);
for(int i=1;i<=n;i++){
scanf("%s",tmp+1);a[i][0]=strlen(tmp+1);
for(int j=1;j<=a[i][0];j++)a[i][j]=tmp[j]-'0';
tr.insert(a[i]);
}tr.getfail();tr.count();
lim[0]=strlen(y+1);for(int i=1;i<=lim[0];i++)lim[i]=y[i]-'0';
ans=ask();
lim[0]=strlen(x+1);for(int i=1;i<=lim[0];i++)lim[i]=x[i]-'0';
ans-=ask();ans=(ans+mo)%mo;printf("%lld",ans);
fclose(stdin);fclose(stdout);
return 0;
}