P5357 【模板】AC自动机(二次加强版)
- 题意:求每个模式串在文本串中出现的次数,并按模式串输入的顺序输出。
思路:
我们以下面这个样例来讲解:
7
she
her
he
he
him
his
e
hisheheheheheher
我们可以得到这个样例的
Trie树
Trie图(只画出了用到的边)
Fail树
我们知道文本串的遍历是在Trie图上,通过fail指针的跳转来走的。现在我们要统计模式串出现的次数,如果我们仍然这样遍历的话,那么复杂度是极高的。就比如上面的样例,hisheheheheheher,如果he再多一些,那么就会在Trie图上的4->5->4循环很多次,fail指针也会一直在he和e中跳转。那当然这是不必要的,所以我们如何优化呢?
首先我们遍历一遍文本串,统计文本串在Tire图中对每个的结点经过次数。
然后,我们的fail指针可以构成一个Fail树,Fail树上每个结点都是一个前缀。而结点fail[ rt ]又是结点 rt 的最长后缀。所以前缀fail[ rt ]出现的次数中必然包括前缀 rt 出现的次数。所以我们构建出Fail树之后,dfs得到每个结点代表的前缀出现的次数即可。
那么答案就是前缀为完整模式串的出现次数咯~
AC CODE
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#define INF 0x3f3f3f3f
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const int maxS = 2000006;
const int maxN = 200005 * 26;
struct AC_automat{
int trie[maxN][26], tot;
int fail[maxN];
int endd[maxS];//endd[i]:模式串i对应的结尾结点标号
int times[maxN];//times[i]:结尾结点标号为i的模式串出现的次数
void Clear(int rt)
{
for(int i = 0; i < 26; ++ i)
trie[rt][i] = 0;
times[rt] = 0;
}
void init()
{
Clear(0);
tot = 0;
memset(endd, 0, sizeof(endd));
}
void Insert(char *s, int p)
{
int rt = 0;
for(int i = 0; s[i]; ++ i)
{
int id = s[i] - 'a';
if(!trie[rt][id]) { trie[rt][id] = ++ tot; Clear(tot); }
rt = trie[rt][id];
}
endd[p] = rt;
}
void build()
{
memset(fail, 0, sizeof(fail));
queue<int>q;
for(int i = 0; i < 26; ++ i) if(trie[0][i]) q.push(trie[0][i]);
while(!q.empty())
{
int rt = q.front(); q.pop();
for(int i = 0; i < 26; ++ i)
{
if(trie[rt][i])
{
fail[trie[rt][i]] = trie[fail[rt]][i];
q.push(trie[rt][i]);
} else trie[rt][i] = trie[fail[rt]][i];
}
}
}
vector<int>vt[maxN];
void build_Fail(char * t)
{
int rt = 0;
for(int i = 0; t[i]; ++ i)
++ times[rt = trie[rt][t[i] - 'a']];
for(int i = 1; i <= tot; ++ i)
vt[fail[i]].push_back(i);
}
void dfs(int u)
{
int siz = vt[u].size();
for(int i = 0; i < siz; ++ i)
{
int v = vt[u][i];
dfs(v);
times[u] += times[v];
}
}
void print(int n)
{
for(int i = 0; i < n; ++ i)
printf("%d\n", times[endd[i]]);
}
}ac_auto;
int n; char s[maxS];
int main()
{
ac_auto.init();
scanf("%d", &n);
for(int i = 0; i < n; ++ i)
{
scanf("%s", s);
ac_auto.Insert(s, i);
}
ac_auto.build();
scanf("%s", s);
ac_auto.build_Fail(s);
ac_auto.dfs(0);
ac_auto.print(n);
return 0;
}
/*
7
she
her
he
he
him
his
e
hisheheheheheher
*/