【模板】AC自动机(简单版)
题目链接:luogu P3808
题目大意
求一堆字符串中,有几个字符串在一个总字符串中出现过。
思路
这道题主要就是写 AC 自动机的模板题。
而 AC 自动机是什么呢?
其实就是把 KMP 和 Trie 树结合起来。
从而使其能够让 KMP 的功能变强,从只能匹配两个字符串,变成了能够匹配多个字符串跟一个字符串。
那怎么做呢?
KMP 中有一个 fail 数组,表示这个地方匹配不了的话可以直接去哪里继续匹配,从而减短用时。
那我们要让他在 Trie 树上做,就可以了。(因为你这个 Trie 树已经涵盖了多个字符串)
那怎么构建出 Trie 树上的 fail 数组呢?
我们考虑两种情况:
- 某个点 i i i 有下一个值是 k k k 的儿子点 j j j,那这个儿子匹配失败的时候就轮到匹配这个点 fail 值对应的点的儿子。也就是说: f a i l j = f a i l i . s o n k fail_j=fail_i.son_k failj=faili.sonk。
- 那如果没有这个儿子,那我们为了快一点,就直接让这个儿子变成到这个点支配之后的点的儿子即可。也就是说:
i
.
s
o
n
k
=
f
a
i
l
i
.
s
o
n
k
i.son_k=fail_i.son_k
i.sonk=faili.sonk。
(当然这个儿子实际上是不存在的)
那对于初始化,就是连向根节点的 fail 值都是根节点。
然后我们用队列来跑一个 bfs,就可以得出我们要的 fail 边,这个 Trie 树就变成了一个 Trie 图。
接着我们来看怎么找到答案。
首先,我们枚举结束的字符串的位置。
接着,我们就开始不停的跳 fail 边,然后每条一次(包括还没有跳的时候),我们就看一下当前这个字符串有没有出现过。
(至于重复的字符串,我们就把这写字符串看成一个,然后加的时候直接加个数即可)
然后因为你这样找的时候可能会找到很多次同一个字符串,导致计算重复(因为可能这个字符串在这个字符串里面不止出现了一次,而我们只是算的是有多少个字符串在这个大字符串中出现过,那无论这个字符串出现了多少次,我们都只算一次)。
那我们就可以统计一下这个字符串是否有被记录过,如果已经记录过,就可以不记录了。
最后输出一下结果,就可以了。
代码
#include<queue>
#include<cstdio>
#include<cstring>
using namespace std;
struct trie {
int son[31], num, fail;
bool use;
}tree[1000001];
int n, size, KK, now, thi, ans, noww;
char c[1000001];
void build() {//建Trie树
size = strlen(c);
now = 0;
for (int i = 0; i < size; i++) {
thi = c[i] - 'a';
if (!tree[now].son[thi]) tree[now].son[thi] = ++KK;
now = tree[now].son[thi];
}
tree[now].num++;
}
void get_fail() {//建fail边
queue <int> q;
for (int i = 0; i < 26; i++)
if (tree[0].son[i]) {
tree[tree[0].son[i]].fail = 0;
q.push(tree[0].son[i]);
}
while (!q.empty()) {
now = q.front();
q.pop();
for (int i = 0; i < 26; i++) {
if (tree[now].son[i]) {//有儿子
tree[tree[now].son[i]].fail = tree[tree[now].fail].son[i];
q.push(tree[now].son[i]);
}
else tree[now].son[i] = tree[tree[now].fail].son[i];//没有儿子
}
}
}
int getAC() {//算出是否出现
size = strlen(c);
now = 0;
ans = 0;
for (int i = 0; i < size; i++) {
thi = c[i] - 'a';
now = tree[now].son[thi];
noww = now;
while (noww && !tree[noww].use) {
tree[noww].use = 1;
ans += tree[noww].num;
noww = tree[noww].fail;
}
}
return ans;
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) {
scanf("%s", &c);
build();
}
get_fail();
scanf("%s", &c);
printf("%d", getAC());
return 0;
}