在学习AC自动机前,请确保你已经充分理解
KMP算法 ANDTrie字典树
我们将从这样一个问题开始引入AC自动机
Q:给定n个模式串和1个文本串,求有多少个模式串在文本串中出现过
这个问题要怎么解?
用N次KMP吗,这样显然爆炸啊
于是闲着没事干脑袋又十分丰腴的计算机科学家们有了一个奇妙的想法
在Trie上求KMP!(当然实际上只是类似KMP的nxt,定义还是有所不同的)
假设当前有5个模式串’she’, ‘he’, ‘say’, ‘shr’, ‘her’
先建出他们的字典树
建好字典树后我们效仿KMP的nxt数组
在Trie上增加fail失配指针
什么是fail指针
假设当前结点
u
u
u所代表的串为
S
S
S,那么
u
u
u的
f
a
i
l
fail
fail指针指向
最长的,能与
S
S
S的后缀匹配的的Trie树的前缀 的结尾结点
(这都什么 #$*&@%¥^#)
是不是有点被绕晕了,那就看这图感性理解一下吧
也就是说若当前Trie树上结点
u
u
u表示的串为
S
S
S
那么
f
a
i
l
[
u
]
fail[u]
fail[u]指向的结点
v
v
v代表串
T
T
T,一定能与
S
S
S的后缀匹配,且是所有满足条件的
T
T
T中最长的
比如最长的,能与串sh的后缀匹配的 Trie的前缀,只有串h
以及最长的,能与串she的后缀匹配的 Trie的前缀,只有串he
那么这个fail指针要怎么求呢
可以考虑用BFS实现
假设当前从队首取出结点
u
u
u
对于
u
u
u的一个子节点
c
h
[
u
]
[
i
]
ch[u][i]
ch[u][i]
我们从
u
u
u开始不断沿着
f
a
i
l
fail
fail指针向上跳
直到跳到一个结点
v
v
v也有表示字符
i
i
i的子节点
c
h
[
v
]
[
i
]
ch[v][i]
ch[v][i]
那么
c
h
[
u
]
[
i
]
ch[u][i]
ch[u][i]的
f
a
i
l
fail
fail指针指向
c
h
[
v
]
[
i
]
ch[v][i]
ch[v][i]
特别的,如果一直跳到根都没有符合条件的结点
那么
c
h
[
u
]
[
i
]
ch[u][i]
ch[u][i]的
f
a
i
l
fail
fail指针指向根
以及注意所有第二层的结点
f
a
i
l
fail
fail指针都指向根
void build_AC()
{
for(int i=0;i<=25;++i)
if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);//第二层节点fail都指向根
while(!q.empty())
{
int u=q.front(); q.pop();
for(int i=0;i<=25;++i)
{
if(!ch[u][i]) continue;//没有这个子节点就跳过
int tt=fail[u];
while(!ch[tt][i]&&tt) tt=fail[tt];//沿着fail指针找到第一个也有同样子节点的结点
fail[ch[u][i]]=ch[tt][i];
q.push(ch[u][i]);
}
}
}
现在连好了fail指针,匹配就简单了
首先用一个指针指向根
将文本串一位一位送入自动机
若当前指针存在表示文本串下一位的子节点,令指针移向该子节点
否则沿着fail指针不断转移,直到跳到一个存在该子节点的结点,令指针移向该子节点
指针每跳转完成一次,就沿着fail指针统计一次
void query(char *ss,int len)
{
int u=0;
for(int i=0;i<len;++i)
{
int x=ss[i]-'a';
while(!ch[u][x]&&u) u=fail[u];
u=ch[u][x];
for(int t=u;t&&sum[t]!=-1;t=fail[t])
ans+=sum[t],sum[t]=-1;
}
}
Trie图优化
我们发现构造AC自动机以及匹配的时候
如果当前结点没有相应的子结点,那么就要沿着fail走到第一个有相应的儿子的结点,失配过程挺麻烦的
实践中通常把AC自动机改造一下,把没有的边补上
即若
u
u
u不存在孩子
i
i
i,那么就补上这个孩子为
f
a
i
l
[
u
]
fail[u]
fail[u]的孩子
i
i
i
若存在这个孩子,则
f
a
i
l
[
c
h
[
u
]
[
i
]
]
fail[ch[u][i]]
fail[ch[u][i]]直接指向
c
h
[
f
a
i
l
[
u
]
[
i
]
]
ch[fail[u][i]]
ch[fail[u][i]]
void build_AC()
{
for(int i=0;i<=25;++i)
if(ch[0][i]) fail[ch[i][0]]=0,q.push(ch[0][i]);
while(!q.empty())
{
int u=q.front(); q.pop();
for(int i=0;i<=25;++i)
if(ch[u][i]) fail[ch[u][i]]=ch[fail[u]][i],q.push(ch[u][i]);
else ch[u][i]=ch[fail[u]][i];
}
}
这样构造之后得到的图称为
T
r
i
e
Trie
Trie图
补全后匹配时失配时就不用一直跳fail了
void query(char *ss)
{
int len=strlen(ss),u=0;
for(int i=0;i<len;++i)
{
u=ch[u][ss[i]-'a'];
for(int t=u;t&&sum[t]!=-1;t=fail[t])
ans+=sum[t],sum[t]=-1;
}
}
当然并不是所有情况都能适用这个优化,有时还是需要保留原来的结构 (具体什么情况蒟蒻也还没完全明白)
AC自动机の应用
HDU - 2222 Keywords Search
上述问题的果题
#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return x*f;
}
const int maxn=500010;
int Q,n,cnt;
char pat[maxn],txt[maxn<<1];
int ch[maxn][26],fail[maxn],sum[maxn];
queue<int> q;
int ans;
void ins(char *ss,int len)
{
int u=0;
for(int i=0;i<len;++i)
{
int x=ss[i]-'a';
if(!ch[u][x]) ch[u][x]=++cnt;
u=ch[u][x];
}
sum[u]++;
}
void build_AC()
{
for(int i=0;i<=25;++i)
if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);
while(!q.empty())
{
int u=q.front(); q.pop();
for(int i=0;i<=25;++i)
{
if(!ch[u][i]) continue;
int tt=fail[u];
while(!ch[tt][i]&&tt) tt=fail[tt];
fail[ch[u][i]]=ch[tt][i];
q.push(ch[u][i]);
}
}
}
void query(char *ss,int len)
{
int u=0;
for(int i=0;i<len;++i)
{
int x=ss[i]-'a';
while(!ch[u][x]&&u) u=fail[u];
u=ch[u][x];
for(int t=u;t&&sum[t]!=-1;t=fail[t])
ans+=sum[t],sum[t]=-1;
}
}
void init()
{
ans=cnt=0;
memset(sum,0,sizeof(sum));
memset(ch,0,sizeof(ch));
}
int main()
{
Q=read();
while(Q--)
{
n=read(); init();
for(int i=1;i<=n;++i)
{
scanf("%s",&pat);
ins(pat,strlen(pat));
}
scanf("%s",&txt);
build_AC(); query(txt,strlen(txt));
printf("%d\n",ans);
}
return 0;
}
P3796 【模板】AC自动机(加强版)
Q:有N个由小写字母组成的模式串以及一个文本串T。每个模式串可能会在文本串中出现多次。你需要找出哪些模式串在文本串T中出现的次数最多。
也是稍作修改即可的果题
#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return x*f;
}
const int maxn=100010;
int n;
char pt[200][100],txt[maxn*10];
int ch[maxn][26],fail[maxn],cnt;
int id[maxn],num[200],ans;
queue<int> q;
void ins(char *ss,int k)
{
int len=strlen(ss),u=0;
for(int i=0;i<len;++i)
{
int x=ss[i]-'a';
if(!ch[u][x]) ch[u][x]=++cnt;
u=ch[u][x];
}
id[u]=k;
}
void build_AC()
{
for(int i=0;i<=25;++i)
if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);
while(!q.empty())
{
int u=q.front(); q.pop();
for(int i=0;i<=25;++i)
if(ch[u][i]) fail[ch[u][i]]=ch[fail[u]][i],q.push(ch[u][i]);
else ch[u][i]=ch[fail[u]][i];
}
}
void query(char *ss,int len)
{
int u=0;
for(int i=0;i<len;++i)
{
u=ch[u][ss[i]-'a'];
for(int t=u;t;t=fail[t])
num[id[t]]++;
}
for(int i=1;i<=n;++i) ans=max(ans,num[i]);
}
void init()
{
ans=cnt=0;
memset(ch,0,sizeof(ch));
memset(id,0,sizeof(id));
memset(num,0,sizeof(num));
}
int main()
{
while(scanf("%d",&n)!=EOF)
{
if(n==0) break; init();
for(int i=1;i<=n;++i)
{
scanf("%s",&pt[i]);
ins(pt[i],i);
}
build_AC();
scanf("%s",&txt);
query(txt,strlen(txt));
printf("%d\n",ans);
for(int i=1;i<=n;++i)
if(num[i]==ans) printf("%s\n",pt[i]);
}
return 0;
}
洛谷P5231 [JSOI2012]玄武密码
给定一个文本串s和m个模式串t,求每个模式串的最长前缀p满足p是文本串s的子串
模式串构建fail指针,文本串送入匹配
文本串每次沿fail跳转经过的结点u所代表的以u为结尾的前缀,一定是s的子串
标记一下,再用模式串模拟建一次Trie,找到标记的最深的节点即可
#include<iostream>
#include<cstdio>
#include<cmath>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long lt;
typedef unsigned int ui;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return f*x;
}
const int maxn=10000010;
int n,m;
char txt[maxn],pt[500010][105];
int ch[maxn][4],fail[maxn];
int dep[maxn],rem[maxn],cnt;
int get(char s)
{
if(s=='E') return 0;
else if(s=='S') return 1;
else if(s=='W') return 2;
else return 3;
}
void ins(char *ss)
{
int u=0,len=strlen(ss);
for(int i=0;i<len;++i)
{
int x=get(ss[i]);
if(!ch[u][x]) ch[u][x]=++cnt,dep[cnt]=dep[u]+1;
u=ch[u][x];
}
//rem[u]=1;
}
void ACM()
{
queue<int> q;
for(int i=0;i<4;++i)
if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);
while(!q.empty())
{
int u=q.front(); q.pop();
for(int i=0;i<4;++i)
{
if(!ch[u][i]) ch[u][i]=ch[fail[u]][i];
else fail[ch[u][i]]=ch[fail[u]][i],q.push(ch[u][i]);
}
}
}
void solve(char *ss)
{
int u=0,len=strlen(ss);
for(int i=0;i<len;++i)
{
u=ch[u][get(ss[i])];
for(int t=u;t;t=fail[t])
rem[t]=1;
}
}
int query(char *ss)
{
int res=0;
int u=0,len=strlen(ss);
for(int i=0;i<len;++i)
{
int x=get(ss[i]);
u=ch[u][x];
if(rem[u]) res=max(res,dep[u]);
}
return res;
}
int main()
{
n=read(); m=read();
scanf("%s",&txt);
for(int i=1;i<=m;++i)
{
scanf("%s",&pt[i]);
ins(pt[i]);
}
ACM();
solve(txt);
for(int i=1;i<=m;++i)
printf("%d\n",query(pt[i]));
return 0;
}