关于 AC自动机 及其 trie图优化 (KMP+trie+bfs)的一些体会
AC自动机本质上是在trie树上实现KMP思想
KMP 时间复杂度:O(n) 求出“某一个”单词 出现在哪些地方 出现次数(每次匹配一个串)
AC自动机 时间复杂度:O(n) 求出“每一个”单词 出现在哪些地方 出现次数(每次匹配一堆串)
关于next数组(核心):
KMP是在一个一维的模板串上建立一个next数组,而AC自动机是在二维的trie树上建立一个next数组。
KMP的next:存取从1到i中与 每一个以i结尾最长后缀匹配(匹配即相等)的最长前缀的长度,也就是最长前缀尾端点的下标(而且是非平凡前后缀,即前缀不包含最后一个字符 或 后缀不包含第一个字符)。
类比到AC自动机也一样,即trie树中的每一个节点都会存一个next[i]。
因此:
AC自动机的next:以 某个节点i 结尾 的所有 非平凡后缀中的某一个 和 trie树中某个非平凡前缀
匹配(匹配即相等)的话,存的是 最长前缀 的尾结点下标。
(如果不存在相同前缀则next[i]指向0,根节点和第一层所有节点的next均为0)
举个例子,比如在这个trie树中(虽然yxc画的有点乱,但是利于理解trie树上的next数组):
可以看到节点e的next指向的则为其上一层中的e节点(最长非平凡前后缀为he)
KMP求next代码:
void get_next()
{
for(int i=2;i<=n;++i)
{
int j=ne[i-1];
while(j&&(p[i]!=p[j+1])) j=ne[j];
if(p[i]==p[j+1]) j++;
ne[i]=j;
}
}
AC自动机求next思想与KMP类似,也是利用前i-1层节点信息求第i层节点的信息,因此在树上要一层层来做,显然要用宽搜。
即将KMP中的for循环扩展为一个宽搜的形式,一层层搜。
求next数组流程:(一维KMP 对应到 二维AC自动机)
while(hh<=tt)
{
①取出队头:t=q[hh++];//用前面的层算后面的,t即为KMP中的i-1
②遍历t所有儿子:
for(int i=0 i<26 ++i)
枚举所有字母,其实对应到KMP中就是p[i],只是和KMP不一样的是枚举p[i]的所有取值(for中的i对应KMP中的p[i])
{
c = trie[t,i] //t的第i个儿子,即KMP中的i
j = next[t] //对应KMP中的j=next[i-1];
接下来,如果在KMP中的话,就是判断p[i]与p[j+1](j下一位)是否相等,但是在trie树中,j下一位不止有一个字母,
因此我们就来判断j这个字母下一个位置(j的第i个位置,注意是a,b,c...z26个字母中的第i位,仔细体会)
是否存在p[i]这个字母。
即:while(j&&(!trie[j,i])) 则j=next[j];
if(trie[j,i])如果存在,则j=trie[j,i] 走到这个字母上去,对应KMP中的j++
之后,next[c]=j,对应KMP中的 next[i]=j
q[++tt]=c;//入队列
}
}
AC自动机求next代码(宽搜):
void bfs()
{
hh = 0, tt = -1;
for (int i = 0; i < 26; ++i)
{
if (trie[0][i]) q[++tt] = trie[0][i];
}
while (hh <= tt)
{
int t = q[hh++];
for (int i = 0; i < 26; ++i)
{
int c = trie[t][i];
if (!c) continue;
int j = ne[t];
while (j && (!trie[j][i])) j = ne[j];
if (trie[j][i]) j = trie[j][i];
ne[c] = j;
q[++tt] = c;
}
}
}
KMP和AC自动机匹配过程不再赘述,将AC自动机的匹配思想和KMP的匹配思想联系起来,与求next数组流程基本一致,细节见代码
#include<bits/stdc++.h>
using namespace std;
const int N = 55, M = 1e4 + 10;
const int S = 1e6 + 10;
int trie[N * M][26], cnt[N * M];//trie数组是字典树,存取了所有模板串(单词)
int ne[N * M];//trie树的next数组
int q[N * M], hh, tt;//广搜队列
int idx;
char s[S];//长文本串
void insert(char str[])//与trie模板插入操作一样
{
int p = 0;
for (int i = 0; str[i]; ++i)
{
int u = str[i] - 'a';
if (!trie[p][u]) trie[p][u] = ++idx;
p = trie[p][u];
}
cnt[p]++;
}
void bfs()
{
hh = 0, tt = -1;
for (int i = 0; i < 26; ++i)
//因为根节点和第一层节点的next指向的都是根节点(0),所以在宽搜时直接从第一层节点开始搜即可
{
if (trie[0][i]) q[++tt] = trie[0][i];//将第一层节点入队(存在才入队开始搜)
}
//接下来经典宽搜
while (hh <= tt)
{
int t = q[hh++];//取出队头用前面的层算后面的,t即为KMP中的i-1
for (int i = 0; i < 26; ++i)
{
int c = trie[t][i];//t的第i个儿子,即KMP中的i
if (!c) continue;//如果没有这个儿子,当然就没必要继续了,直接continue
int j = ne[t];//对应KMP中j=next[i-1]
while (j && (!trie[j][i])) j = ne[j];//如j的第i个位置不存在节点的话,j就往前挪,想KMP
if (trie[j][i]) j = trie[j][i];//如果有这个儿子,就走过去
ne[c] = j;//赶紧记录下来
q[++tt] = c;//入队列
}
}
}
int main()
{
cin.tie(0), ios::sync_with_stdio(false);
int t;
cin >> t;
while (t--)
{
int n;
memset(trie, 0, sizeof trie);
memset(cnt, 0, sizeof cnt);
memset(s, '\0', sizeof s);
memset(ne, 0, sizeof ne);
idx = 0;
cin >> n;
for (int i = 0; i < n; ++i)
{
char word[M];
cin >> word;
insert(word);//将每一个单词插入到trie树
}
bfs();//其本质:求ne数组
int res = 0;
cin >> s;
for (int i = 0,j=0; s[i]; ++i)//j在for里
{
int t = s[i] - 'a';
while (j && (!trie[j][t])) j = ne[j];
//当j不存在t这个儿子,类似于KMP(s[i]!=p[j+1]),j=ne[j]进行回退
if (trie[j][t]) j = trie[j][t];//如果存在,j就走过去
//现在我们就找到了:当前字母s[i]能匹配到的trie树中“最深”的一个节点,即当前j的位置,首先j这个单
//词是必然出现过的(上方的“if (trie[j][t])”),但是还有一些单词是需要注意的,比如she出现过的话,
//其最长非平凡前缀he也必定出现过(he本来就是存在模板串中的单词,前者she可以匹配到,后者一样可以)
//因此在下方res累加答案时,不仅要加上j,同时还要加上j的ne[j]能走到的位置,即遍历所有next,把能加
//的都加上
int p = j;
while (p)
{
res += cnt[p];//累加答案(以p结尾单词数量)
cnt[p] = 0;/注意上方res加了以p结尾的单词数量,以后就不要再重复加了,要置为0
p = ne[p];//遍历next
}
}
cout << res << endl;
}
return 0;
}
无注释代码
#include<bits/stdc++.h>
using namespace std;
const int N = 55, M = 1e4 + 10;
const int S = 1e6 + 10;
int trie[N * M][26], cnt[N * M];
int ne[N * M];
int q[N * M], hh, tt;
int idx;
char s[S];
void insert(char str[])
{
int p = 0;
for (int i = 0; str[i]; ++i)
{
int u = str[i] - 'a';
if (!trie[p][u]) trie[p][u] = ++idx;
p = trie[p][u];
}
cnt[p]++;
}
void bfs()
{
hh = 0, tt = -1;
for (int i = 0; i < 26; ++i)
{
if (trie[0][i]) q[++tt] = trie[0][i];
}
while (hh <= tt)
{
int t = q[hh++];
for (int i = 0; i < 26; ++i)
{
int c = trie[t][i];
if (!c) continue;
int j = ne[t];
while (j && (!trie[j][i])) j = ne[j];
if (trie[j][i]) j = trie[j][i];
ne[c] = j;//
q[++tt] = c;
}
}
}
int main()
{
cin.tie(0), ios::sync_with_stdio(false);
int t;
cin >> t;
while (t--)
{
int n;
memset(trie, 0, sizeof trie);
memset(cnt, 0, sizeof cnt);
memset(s, '\0', sizeof s);
memset(ne, 0, sizeof ne);
idx = 0;
cin >> n;
for (int i = 0; i < n; ++i)
{
char word[M];
cin >> word;
insert(word);
}
bfs();//其本质:求ne数组
int res = 0;
cin >> s;
for (int i = 0,j=0; s[i]; ++i)//
{
int t = s[i] - 'a';
while (j && (!trie[j][t])) j = ne[j];
if (trie[j][t]) j = trie[j][t];
int p = j;
while (p)
{
res += cnt[p];
cnt[p] = 0;
p = ne[p];
}
}
cout << res << endl;
}
return 0;
}
虽然这个算法时间复杂度是线性的,但是由于常数比较大,匹配时因为每次都要跳多次next指针进行回溯,复杂度上界可以达到 O(ml),我们可以想到,“如果失配时可以一步到位就好了。每次回溯的过程是固定的:一直跳,直到找到拥有儿子c的节点为止。因此无论什么时候在这个节点上失配,只要你找的是字符c,你总会在固定的节点上重新开始匹配。既然这样,不如直接把那个字符为c的节点变成自己的儿子,就可以省去回溯的麻烦” 。
链接:①AcWing 1282. AC自动机为何,如何优化成Trie图?
②AcWing 1282. trie图【形象】
因此对bfs算法进行优化:
while (hh <= tt)
{
int t = q[hh++];
for (int i = 0; i < 26; ++i)
{
int c = trie[t][i];
//对于两种情况我们要找的点都是trie[ne[t]][i]
if (!c)
{
trie[t][i]=trie[ne[t]][i];
//如果该儿子不存在则用前一层的结果(前面的层已经是正确的)使其指到父节点next指针对应的第i个儿子上去
}
else
{
//如果该儿子存在则记录ne[c]
ne[c]=trie[ne[t]][i];
q[++tt]=c;
}
//其正确性可以用数学归纳法
}
}
由于遍历到t点的时候t的儿子们的ne数组值已经更新过了,因此,必然可以一路递推到对应的子节点上
因为原本是DAG有向无环图结构的AC自动机出现了环,因此称为Trie图,此时可以做到真正的O(m)
匹配代码优化:
int res = 0;
cin >> s;
for (int i = 0,j=0; s[i]; ++i)
{
int t = s[i] - 'a';
j=trie[j][t];//经过bfs的优化后,匹配过程也得到了优化(省去了while,减小了常数)
int p = j;
while (p)
{
res += cnt[p];
cnt[p] = 0;
p = ne[p];
}
}
优化为trie图代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 55, M = 1e4 + 10;
const int S = 1e6 + 10;
int trie[N * M][26], cnt[N * M];
int ne[N * M];
int q[N * M], hh, tt;
int idx;
char s[S];
void insert(char str[])
{
int p = 0;
for (int i = 0; str[i]; ++i)
{
int u = str[i] - 'a';
if (!trie[p][u]) trie[p][u] = ++idx;
p = trie[p][u];
}
cnt[p]++;
}
void bfs()
{
hh = 0, tt = -1;
for (int i = 0; i < 26; ++i)
{
if (trie[0][i]) q[++tt] = trie[0][i];
}
while (hh <= tt)
{
int t = q[hh++];
for (int i = 0; i < 26; ++i)
{
int c = trie[t][i];
if (!c)
{
trie[t][i]=trie[ne[t]][i];
}
else
{
ne[c]=trie[ne[t]][i];
q[++tt]=c;
}
}
}
}
int main()
{
cin.tie(0), ios::sync_with_stdio(false);
int t;
cin >> t;
while (t--)
{
int n;
memset(trie, 0, sizeof trie);
memset(cnt, 0, sizeof cnt);
memset(s, '\0', sizeof s);
memset(ne, 0, sizeof ne);
idx = 0;
cin >> n;
for (int i = 0; i < n; ++i)
{
char word[M];
cin >> word;
insert(word);
}
bfs();//其本质:求ne数组
int res = 0;
cin >> s;
for (int i = 0,j=0; s[i]; ++i)
{
int t = s[i] - 'a';
j=trie[j][t];/
int p = j;/
while (p)
{
res += cnt[p];
cnt[p] = 0;
p = ne[p];
}
}
cout << res << endl;
}
return 0;
}