问题:有多个单词,然后给出一个较长的字符串,问有多少个单词出现在这个字符串里面。这里要用到的就是AC自动机,AC自动机需要KMP和字典树知识,最好把这两样都弄懂了。
AC自动机其实就是字典树里面再加一个fail指针,这个fail指针的用处就是当当前字符无法被匹配时所要进行的转移,很想KMP算法里面的next数组,比如对于单词:qwert, wert, ert所构成的字典树如下:(世上最丑作图。。丑出境界了哈哈哈)
我自己都要看吐了。。
然后加了fail指针之后:
再一次丑出天际。
(PS:其实这里在真正建fail指针的时候是有一点不同的,因为字典树要从a扫到z所以会先扫到e再扫到q再到w,不过其实都是一样的)
还是老套路,结构体:
const int tp = 27;
struct node {
node *next[tp]; //字符集个数
node *fail; // fail指针,AC自动机最重要的东西
bool ifc; // 是否是单词结尾
};
然后建立插入单词其实是跟字典树的一样的,这里就不多赘述了,具体可以看一下本人博客里字典树的代码注释:
node *getNewNode() {
node *p = new node;
for (int i = 0; i < tp; ++i) p->next[i] = NULL;
p->fail = NULL;
p->ifc = false;
return p;
}
void Insert(char *s) {
int len = strlen(s);
node *p = root;
for (int i = 0; i < len; ++i) {
int idx = s[i] - 'a';
if (p->next[idx] == NULL) p->next[idx] = getNewNode();
p = p->next[idx];
}
p->ifc = true;
}
插入完单词以后就是最重要的构建fail指针了,说白了跟KMP的next指针构建原理很像,再看这张丑到爆的图
比如说qwert,这个单词里面的e,他的fail指针,以e结尾的子串是qwe,这个子串的后缀为we 和 e,fail指针指向的就是某个以we或者e为前缀的单词里面的那个e(说的我自己都要晕了),如果不存在的话就指向root,还有root的儿子节点的fail指针都指向root。
看图 qwe中we是qwe的后缀,也是wert的前缀,所以e指向wert的e,就是这么个情况。建立fail指针的时候用到的是BFS,需要一层一层的建立,因为子节点的fail指针需要父节点的fail指针。
struct node *root;
struct node *que[maxn]; // 队列
void buildACnode() {
int l = 0, r = 1;
que[0] = root;
while (l < r) { // 用数组模拟队列
node *tmp = que[l++];
if (tmp == root) { // root儿子节点的fail指向root
for (int i = 0; i < tp; ++i)
if (tmp->next[i]) {
tmp->next[i]->fail = root;
que[r++] = tmp->next[i]; // 别忘了加入队列中
}
} else {
for (int i = 0; i < tp; ++i) {
if (tmp->next[i]) {
node *p = tmp->fail;
while (p && !p->next[i]) p = p->fail;//看父节点的fail指针指向的节点是否有i这个节点,直到p为空或者找到符合的节点
if (!p) tmp->next[i]->fail = root;//没有符合的节点,指向root
else tmp->next[i]->fail = p->next[i];
que[r++] = tmp->next[i];
}
}
}
}
}
最后就剩下查询了:
查询其实就是顺着字典树往下走,如果匹配这个节点那么接着往下走,否则跳转到这个节点的fail指针继续寻找是否有匹配节点
int query(char *str) {
int cnt = 0, len = strlen(str);
node *p = root;
for (int i = 0; i < len; ++i) {
int id = str[i] - 'a';
while (!p->next[id] && p != root) p = p->fail; //若当前没有能够匹配的节点则跳转到该节点的fail指针继续检测
p = p->next[id];
if (!p) p = root; // 遍历完所有fail指针都没有那么返回到root,已经不可能继续匹配这个单词了
node *t = p;
while (t != root) { // 遍历这个节点的所有直接间接相连的fail指针,看有没有某个节点是单词结束字符
if (t->ifc) {
cnt++;
// t->ifc = false;
}// else break;
t = t->fail;
}
}
return cnt;
}
然后就大功告成了!
总代码:
#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
const int tp = 27;
const int maxn = 100000 + 7;
struct node {
node *next[tp];
node *fail;
bool ifc;
};
struct node *root;
struct node *que[maxn];
node *getNewNode() {
node *p = new node;
for (int i = 0; i < tp; ++i) p->next[i] = NULL;
p->fail = NULL;
p->ifc = false;
return p;
}
void Insert(char *s) {
int len = strlen(s);
node *p = root;
for (int i = 0; i < len; ++i) {
int idx = s[i] - 'a';
if (p->next[idx] == NULL) p->next[idx] = getNewNode();
p = p->next[idx];
}
p->ifc = true;
}
void buildACnode() {
int l = 0, r = 1;
que[0] = root;
while (l < r) {
node *tmp = que[l++];
if (tmp == root) {
for (int i = 0; i < tp; ++i)
if (tmp->next[i]) {
tmp->next[i]->fail = root;
que[r++] = tmp->next[i];
}
} else {
for (int i = 0; i < tp; ++i) {
if (tmp->next[i]) {
node *p = tmp->fail;
while (p && !p->next[i]) p = p->fail;
if (!p) tmp->next[i]->fail = root;
else tmp->next[i]->fail = p->next[i];
que[r++] = tmp->next[i];
}
}
}
}
}
int query(char *str) {
int cnt = 0, len = strlen(str);
node *p = root;
for (int i = 0; i < len; ++i) {
int id = str[i] - 'a';
while (!p->next[id] && p != root) p = p->fail;
p = p->next[id];
if (!p) p = root;
node *t = p;
while (t != root) {
if (t->ifc) {
cnt++;
// t->ifc = false;
}// else break;
t = t->fail;
}
}
return cnt;
}
char s[1000];
int main() {
root = getNewNode();
int n;
scanf("%d", &n);
for (int i = 0; i < n; ++i) {
scanf("%s", s);
Insert(s);
}
buildACnode();
scanf("%s", s);
int res = query(s);
printf("answer:%d\n", res);
return 0;
}
/**
3
qwert
wert
ert
qwerty
**/