题目地址:
https://www.acwing.com/problem/content/description/1284/
给定 n n n个长度不超过 50 50 50的由小写英文字母组成的单词,以及一篇长为 m m m的文章。请问,其中有多少个单词在文章中出现了。注意:每个单词不论在文章中出现多少次,仅累计 1 1 1次。
输入格式:
第一行包含整数
T
T
T,表示共有
T
T
T组测试数据。
对于每组数据,第一行一个整数
n
n
n,接下去
n
n
n行表示
n
n
n个单词,最后一行输入一个字符串,表示文章。
输出格式:
对于每组数据,输出一个占一行的整数,表示有多少个单词在文章中出现。
数据范围:
1
≤
n
≤
1
0
4
1≤n≤10^4
1≤n≤104
1
≤
m
≤
1
0
6
1≤m≤10^6
1≤m≤106
思路是AC自动机。AC自动机的思想类似于Trie + KMP。先将所有单词建立成一个Trie,然后将其视为自动机,考虑在某个位置失配后应该跳到哪里继续匹配。如果树根的位置就失配了,显然只能继续从树根开始匹配;如果第 1 1 1层(即树根下一层)的节点失配了,也只能跳回树根继续匹配;考虑每个节点应该跳转到哪里继续匹配,如果某个节点代表的串是 s s s,那么如果在此处失配了,意味着在读入下一个字符的时候没有边可以继续向下走,由于我们已经知道了 s s s这个字符串这么长已经匹配成功,所以其后缀也能匹配,我们找到其最长后缀,并且使得这个后缀在Trie里存在,那么这个后缀对应的终点就是应该跳到的地方。
考虑怎么求,设 n [ v ] n[v] n[v]表示 v v v处失配时应该跳到哪里继续匹配。假设第 ≤ k \le k ≤k层的所有节点的失配之后跳转的位置都已经求出了,设 u u u是第 k + 1 k+1 k+1层的某个节点, p p p是其父亲,并且 p p p通过 c c c走到了 u u u,即 t [ p ] [ c ] = u t[p][c]=u t[p][c]=u(这个式子可以看成是一个自动机的转移,即 p p p位置读了 c c c转移到 u u u),我们即是要找 u u u的最大非平凡后缀,先看 n [ p ] [ c ] n[p][c] n[p][c]是否存在,如果存在,则 n [ u ] = t [ n [ p ] ] [ c ] n[u]=t[n[p]][c] n[u]=t[n[p]][c],如果不存在,则令 p 1 = n [ p ] p_1=n[p] p1=n[p],再看 n [ p 1 ] [ c ] n[p_1][c] n[p1][c]是否存在,如果存在,则 n [ u ] = t [ n [ p 1 ] ] [ c ] n[u]=t[n[p_1]][c] n[u]=t[n[p1]][c],以此类推,如果跳了若干次依然不存在,则到树根。这样在Trie的基础上,把 n n n数组求出来,这个数组又叫做next数组。
查询的时候,和KMP一样,顺着边向下走,如果走到 u u u继续走不动了,则走到 n [ u ] n[u] n[u]继续匹配;如果发现能匹配,则要将当前点 u u u的计数 c [ u ] c[u] c[u],以及 c [ n [ u ] ] , c [ n [ n [ u ] ] ] . . . c[n[u]],c[n[n[u]]]... c[n[u]],c[n[n[u]]]...全加上去。代码如下:
#include <iostream>
#include <cstring>
using namespace std;
const int N = 1e4 + 10, S = 55, M = 1e6 + 10;
int n;
int tr[N * S][26], cnt[N * S], idx;
char s[M];
int q[N * S], ne[N * S];
void insert() {
int p = 0;
for (int i = 0; s[i]; i++) {
int t = s[i] - 'a';
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p]++;
}
void build() {
int hh = 0, tt = 0;
// 树根和树根下一层的失配边都连到树根上,所以可以从树根下一层开始BFS
for (int i = 0; i < 26; i++)
if (tr[0][i]) q[tt++] = tr[0][i];
while (hh < tt) {
int t = q[hh++];
for (int i = 0; i < 26; i++) {
int c = tr[t][i];
if (!c) continue;
int j = ne[t];
// 一直跳到有i这条边为止
while (j && !tr[j][i]) j = ne[j];
// 如果有i这条边,则知道了j应该跳到tr[j][i],否则只能跳树根了
if (tr[j][i]) j = tr[j][i];
ne[c] = j;
// 继续遍历下一层
q[tt++] = c;
}
}
}
int query() {
int res = 0;
for (int i = 0, j = 0; s[i]; i++) {
int t = s[i] - 'a';
// 一直跳到有t这条边的节点处
while (j && !tr[j][t]) j = ne[j];
// 如果存在,则向下走一步,否则j回到树根
if (tr[j][t]) j = tr[j][t];
int p = j;
// 累加计数
while (p) {
res += cnt[p];
cnt[p] = 0;
p = ne[p];
}
}
return res;
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(ne, 0, sizeof ne);
idx = 0;
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%s", s);
insert();
}
build();
scanf("%s", s);
printf("%d\n", query());
}
}
每组数据时空复杂度 O ( n ) O(n) O(n), n n n为Trie的节点总个数, h h h是Trie高度,即最长模式串长度。每次在匹配的时候,累加计数都需要跳若干次,最差跳 h h h步。
可以优化为Trie图,其实就是建Trie的时候,将每个节点的失配边连到它在原Trie中最终应该跳到的地方即可。设Trie中 t [ p ] [ i ] = u t[p][i]=u t[p][i]=u,如果 u ≠ 0 u\ne 0 u=0,那么和上面一样, n [ u ] = t [ n [ p ] ] [ i ] n[u] = t[n[p]][i] n[u]=t[n[p]][i];如果 u = 0 u=0 u=0,那么在点 p p p如果读到了 i i i边的字符,其实是发生了失配,由归纳假设, n [ p ] n[p] n[p]已经是 p p p失配的时候要跳到的最终位置,那么令 t [ p ] [ i ] = t [ n e [ p ] ] [ i ] t[p][i]=t[ne[p]][i] t[p][i]=t[ne[p]][i],这样即使遇到失配边也可以直接走,不用跳多步。此外,累加的时候也可以加优化,累加过的点标记一下。代码如下:
#include <iostream>
#include <cstring>
using namespace std;
const int N = 1e4 + 10, S = 55, M = 1e6 + 10;
int n;
int tr[N * S][26], cnt[N * S], idx;
char s[M];
int q[N * S], ne[N * S];
void insert() {
int p = 0;
for (int i = 0; s[i]; i++) {
int t = s[i] - 'a';
if (!tr[p][t]) tr[p][t] = ++idx;
p = tr[p][t];
}
cnt[p]++;
}
void build() {
int hh = 0, tt = 0;
for (int i = 0; i < 26; i++)
if (tr[0][i]) q[tt++] = tr[0][i];
while (hh < tt) {
int t = q[hh++];
for (int i = 0; i < 26; i++) {
int p = tr[t][i];
// 如果p不空,t读i不会失配,则p的ne指针应该连到
// 其父亲的ne指针的i边,和未优化的上面一样
if (p) ne[p] = tr[ne[t]][i], q[tt++] = p;
// 如果p空,则处于t读i边会失配,那么直接将最终跳的点连到下面即可
else tr[t][i] = tr[ne[t]][i];
}
}
}
int query() {
int res = 0;
for (int i = 0, j = 0; s[i]; i++) {
int t = s[i] - 'a';
// 不需要while循环了
j = tr[j][t];
int p = j;
// 另外一个优化,这里是优化累加的,把累加过的点标记一下
while (p && ~cnt[p]) {
res += cnt[p];
cnt[p] = -1;
p = ne[p];
}
}
return res;
}
int main() {
int T;
scanf("%d", &T);
while (T--) {
memset(tr, 0, sizeof tr);
memset(cnt, 0, sizeof cnt);
memset(ne, 0, sizeof ne);
idx = 0;
scanf("%d", &n);
for (int i = 0; i < n; i++) {
scanf("%s", s);
insert();
}
build();
scanf("%s", s);
printf("%d\n", query());
}
}
每组数据时空复杂度 O ( n ) O(n) O(n), n n n是Trie节点个数。