AC自动机
1. AC自动机原理
原理
KMP
- AC自动机是在KMP的基础上进行扩展,在KMP中,我们存在模式串p以及被匹配的串s,我们可以通过KMP算法在 O ( n ) O(n) O(n)的时间内判断p是否在s中出现过、出现的位置以及出现的次数。AC自动机实质上是将模式串p换成了trie树,为了理解AC自动机,我们需要深入理解KMP算法,关于KMP的算法原理如下:
- 这里分析一下一下next数组的求解代码
// 求next数组, ne[1]=0表示如果p[1]没有匹配上,从头开始匹配
for (int i = 2, j = 0; i <= n; i++) {
// 这里的j起始就是ne[i-1]
// 因为当i=2时,ne[2-1]=0,符合要求;之后每次循环最后ne[i]被赋值为j,然后i++, 因此ne[i - 1]就是j
while (j && p[i] != p[j + 1]) j = ne[j];
if (p[i] == p[j + 1]) j++;
ne[i] = j;
}
根据上面代码注释中的解释,上述代码等价于:
// 求next数组, ne[1]=0表示如果p[1]没有匹配上,从头开始匹配
for (int i = 2; i <= n; i++) {
int j = ne[i - 1];
while (j && p[i] != p[j + 1]) j = ne[j];
if (p[i] == p[j + 1]) j++;
ne[i] = j;
}
因此上述代码就可以理解为,根据ne[0]~ne[i-1]
求解ne[i]
。
AC自动机
-
对应到AC自动机中,KMP中的模式串p需要变为一个trie树,我们其实也可以将KMP中的模式串p看成trie树,只不过这棵树是一个单链而已。我们需要在trie树中求解next数组。
-
类似于KMP中的next数组定义(next[j]=k表示后缀等于非平凡前缀的最大长度对应的下标),AC自动机中next数组定义:next[x]中存储的是trie中的某个节点y,节点y满足从根节点到y代表的字符串等于以节点x为结尾的等长的字符串,且该字符串是非平凡中最长的一个。
-
下面以
she、he、say、shr、her
这5个单词为例,讲解一下next数组的求解过程:(1)首先要建立trie树,如下图(假设代表单词的5个节点编号为1~5,实际trie树中不是1~5,这里为了讲解方便):
根据定义可知,上图中节点1对应的next值为5,即next[1]=5。
(2)类似于kmp中next[0]=next[1]=0,这里trie树中的第一层和第二层节点的next值也都为0,最终建立出来的trie树如下图:
-
对应到代码上,类似于KMP算法根据
ne[0]~ne[i-1]
求解ne[i]
,我们可以在这棵树上做BFS,根据前i-1层的结果来求解第i层的结果。
while (hh <= tt)
{
int t = q[hh ++ ];
for (int i = 0; i < 26; i ++ ) // 这里的i枚举的是字母
{
int c = tr[t][i]; // 字母i(0代表'a')对应的节点编号为c
if (!c) continue; // 说明从节点t不能走到字母i
int j = next[t];
while (j && !tr[j][i]) j = next[j]; // !tr[j][i]代表不能从节点j到达字母i
if (tr[j][i]) j = tr[j][i]; // 如果能达到,更新节点j对应的编号
next[c] = j;
q[++tt] = c;
}
}
代码对应的图示如下:
-
匹配过程类似于这里next的求解过程,这里省略。
-
AC自动机的时间复杂度也是线性的。
trie图
- 对AC自动机进行优化,可以得到trie图。思路是想要将最内层的while循环替换掉,优化一下常数。思想是非类似于路径压缩,因为while循环可能向上跳很多次,我们可以让它跳的时候一步到位,如下图:
// 此时,tr的定义被更新了,如果有i这个儿子,向下跳;否则不存在这个孩子,直接跳到next指针应该走到的位置
while (hh <= tt)
{
int t = q[hh ++ ];
for (int i = 0; i < 26; i ++ ) // 这里的i枚举的是字母
{
int &p = tr[t][i]; // 字母i(0代表'a')对应的节点编号为c
// 如果不存在到i的边,让节点p指向其父节点t的next指向的位置的第i个儿子
// 即此时tr[t][i]存储了next数组应该存储的内容
if (!p) p = tr[next[t]][i];
else {
next[p] = tr[next[t]][i];
q[++tt] = p;
}
}
}
2. AcWing上的AC自动机题目
AcWing 1282. 搜索关键词
问题描述
-
问题链接:AcWing 1282. 搜索关键词
分析
-
本题中的步骤是:
(1)将所有单词存入到trie树中;
(2)然后在trie树上求解next数组;
(3)匹配过程:trie树中的单词匹配文章。
-
对于第(3)步,我们需要注意,对于当前trie树中匹配到的字符串,需要将其最大后缀对应的字符串个数都加上(我们不需要考虑当前字符串是否为输入放入单词,因为不是单词的话,节点中对应的cnt值为0)。
代码
- C++
#include <iostream>
#include <cstring>
using namespace std;
const int N = 10010, S = 55, M = 1000010;
int n; // 单词数量
int tr[N * S][26];
int cnt[N * S]; // 以每个节点结尾的单词的数量
int idx;
char str[M]; // 读取输入字符串
int q[N * S]; // BFS求ne数组时的队列
int ne[N * S];
// trie中的插入函数
void insert() {
int p = 0; // 0既代表根节点,也代表空节点
for (int i = 0; str[i]; i++) {
int t = str[i] - 'a';
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p]++;
}
void build() {
int hh = 0, tt = -1;
// 第一层、第二层对应的ne值都为0,直接将第二层入队即可
for (int i = 0; i < 26; i++)
if (tr[0][i]) // 根节点0存在孩子i
q[++tt] = tr[0][i];
while (hh <= tt) {
int t = q[hh++];
for (int i = 0; i < 26; i++) {
int c = tr[t][i];
if (!c) continue;
int j = ne[t];
while (j && !tr[j][i]) j = ne[j];
if (tr[j][i]) j = tr[j][i];
ne[c] = j;
q[++tt] = c;
}
}
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(ne, 0, sizeof ne);
idx = 0;
// (1) 建立trie数
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%s", str);
insert();
}
// (2) 在trie树上求解next数组
build();
// (3) 匹配过程:trie树中的单词匹配文章
scanf("%s", str);
int res = 0; // 表示匹配的单词的数量
for (int i = 0, j = 0; str[i]; i++) { // 遍历文章中的每个字符
int t = str[i] - 'a';
while (j && !tr[j][t]) j = ne[j];
if (tr[j][t]) j = tr[j][t];
int p = j;
while (p) {
res += cnt[p];
cnt[p] = 0; // 该单词如果出现过,统一一遍即可
p = ne[p];
}
}
printf("%d\n", res);
}
return 0;
}
// trie图
#include <iostream>
#include <cstring>
using namespace std;
const int N = 10010, S = 55, M = 1000010;
int n; // 单词数量
int tr[N * S][26];
int cnt[N * S]; // 以每个节点结尾的单词的数量
int idx;
char str[M]; // 读取输入字符串
int q[N * S]; // BFS求ne数组时的队列
int ne[N * S];
// trie中的插入函数
void insert() {
int p = 0; // 0既代表根节点,也代表空节点
for (int i = 0; str[i]; i++) {
int t = str[i] - 'a';
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p]++;
}
void build() {
int hh = 0, tt = -1;
// 第一层、第二层对应的ne值都为0,直接将第二层入队即可
for (int i = 0; i < 26; i++)
if (tr[0][i]) // 根节点0存在孩子i
q[++tt] = tr[0][i];
while (hh <= tt) {
int t = q[hh++];
for (int i = 0; i < 26; i++) {
int &p = tr[t][i];
if (!p) p = tr[ne[t]][i]; // 不存在到i的边
else {
ne[p] = tr[ne[t]][i];
q[++tt] = p;
}
}
}
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(ne, 0, sizeof ne);
idx = 0;
// (1) 建立trie数
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%s", str);
insert();
}
// (2) 在trie树上求解next数组
build();
// (3) 匹配过程:trie树中的单词匹配文章
scanf("%s", str);
int res = 0; // 表示匹配的单词的数量
for (int i = 0, j = 0; str[i]; i++) { // 遍历文章中的每个字符
int t = str[i] - 'a';
j = tr[j][t];
int p = j;
while (p) {
res += cnt[p];
cnt[p] = 0; // 该单词如果出现过,统一一遍即可
p = ne[p];
}
}
printf("%d\n", res);
}
return 0;
}
AcWing 1285. 单词
问题描述
-
问题链接:AcWing 1285. 单词
分析
-
每一行都是一个单词,需要在所有给定的单词中进行匹配,可以认为模板串和待匹配的串是同一组数据。
-
对于所以输入的单词,建立一棵trie树,对于某个单词我们想要统计出其在其他所有(包含自己)串中出现的次数,我们应该怎么统计呢?
-
首先,我们要明确某个单词出现的次数一定小于等于所有字符串的总长度。
-
一个字符串出现的次数=所有满足要求的前缀个数,要求是这个前缀的的后缀等于原串。
-
这里采用另外一种思路:考虑每一个前缀t,t的后缀等于多少个前缀。相当于反过来考虑,这样一来我们可以迭代求解。
-
对于所有存在的边(i, next[i]),我们都连一条边,则我们会形成一个有向无环图,因为i所在的层一定比next[i]所在的层深。
-
上图中f的含义:
(1)建立trie后, f代表当前节点代表的字符串(必须从根节点开始形成的字符串)出现的次数;
(2)依据拓扑序进行递推即可,即
f[next[i]]+=f[i]
。递推之后, f代表当前节点代表的字符串在整个trie中出现的次数。这个递推过程可以参考上面的原理中的图,如下图(f[5]+=f[1]):
代码
- C++
#include <iostream>
using namespace std;
const int N = 1e6 + 10;
int n; // 单词个数
int tr[N][26], idx;
// 建立trie后, f代表当前节点代表的字符串(必须从根节点开始形成的字符串)出现的次数
// 递推之后, f代表当前节点代表的字符串在整个trie中出现的次数
int f[N];
int q[N]; // BFS求ne时使用到的队列
int ne[N];
char str[N]; // 输入字符串
int id[210]; // 每个单词在trie中对应节点的编号
void insert(int x) {
int p = 0;
for (int i =0; str[i]; i++) {
int t = str[i] - 'a';
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
f[p]++; // 每一个结束的位置都代表一个字符串
}
id[x] = p;
}
void build() {
int hh = 0, tt = -1;
for (int i = 0; i < 26; i++)
if (tr[0][i])
q[++tt] = tr[0][i];
while (hh <= tt) {
int t = q[hh++];
for (int i = 0; i < 26; i++) {
int &p = tr[t][i];
if (!p) p = tr[ne[t]][i];
else {
ne[p] = tr[ne[t]][i];
q[++tt] = p;
}
}
}
}
int main() {
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%s", str);
insert(i); // i是当前单词对应编号
}
// 求解ne
build();
// 递推更新f, trie中节点编号为0~idx,一共idx+1个点,0既代表根节点又代表空节点
for (int i = idx; i; i--) f[ne[q[i]]] += f[q[i]];
for (int i = 0; i < n; i++) printf("%d\n", f[id[i]]);
return 0;
}