更好的阅读体验 Press Here
AC自动机 可食用最佳练手题
题意
有三种操作
* 在字符串的末尾加入小写字母
* 删除字符串末尾的字符
* 将当前字符串输出(不删除)
问第 x x 个输出的字符串 在第 个输出的字符串中出现了几次
Solution
当然直接暴力是没有问题的
存下第
x
x
个输出的字符串 然后用 KMP 优化匹配
时间复杂度可以达到 ,仍然会超时
既然需要对多个字符串进行匹配,自然的我们想到 AC自动机 (后缀自动机的泥奏凯)
发现如果第
x
x
个输出的字符串在第 个字符串的
n
n
位置匹配,那么第 个输出的字符串一定是第
y
y
个字符串的 位置的后缀,即可以通过
fail
f
a
i
l
指针转移到
那么原题就转换为了在第
y
y
个字符串中有多少个字符能够通过 转移到第
x
x
个字符串的结尾位置
但是直接做时间复杂度仍然很高
我们需要枚举每个 串,维护一个计数器,从根一路遍历到
y
y
串的末尾节点,途中对于每个节点,如果其 指针指向的是某
x
x
串的末尾节点,则累加
这个时间复杂度还是 ,无法通过此题
看到多个点通过
fail
f
a
i
l
上找一个点,为何不将其转化为一个点通过
fail
f
a
i
l
找其他点呢?
那么将
fail
f
a
i
l
反向,问题变为
x
x
串的结尾位置能够通过反向的 到达多少属于
y
y
的节点
考虑进一步优化,再这样一棵由反向 组成的树上,每个节点所能到达的点一定在它的子树之中,可以使用 DFS 序维护,同时用树状数组求答案
如何统计?
只需要重新按照原来输入再走一遍,每次遇到加入某个节点,则加入当前节点的 DFS 序;每次删除某个节点,则删除当前节点的 DFS 序;每次输出某个串,说明现在树状数组维护的就是这个串所有节点的 DFS 序,那么就可以统计这个串作为
y
y
<script type="math/tex" id="MathJax-Element-27">y</script> 串的答案了
详细代码如下:
#include <bits/stdc++.h>
using namespace std;
const int N = 100010;
int m , tot , l[N] , r[N] , t[N] , ans[N];
char s[N];
vector <int> e[N];
vector <pair<int , int>> q[N];
void add(int , int);
int query(int);
struct Aho {
int ch[N][26] , fail[N] , fa[N] , id[N] , go[N];
int cnt , total;
void build() {
int n = strlen(s) , now = 0;
for(int i = 0 ; i < n ; ++ i) {
if(s[i] == 'B') now = fa[now];
else if(s[i] == 'P') {id[++ total] = now; go[now] = total;}
else {
if(!ch[now][s[i] - 'a']) {
ch[now][s[i] - 'a'] = ++ cnt;
fa[cnt] = now;
}
now = ch[now][s[i] - 'a'];
}
}
queue <int> q;
for(int i = 0 ; i < 26 ; ++ i)
if(ch[0][i]) {
q.push(ch[0][i]);
e[0].push_back(ch[0][i]);
}
while(!q.empty()) {
int x = q.front(); q.pop();
for(int i = 0 ; i < 26 ; ++ i) {
if(ch[x][i]) {
q.push(ch[x][i]);
fail[ch[x][i]] = ch[fail[x]][i];
e[ch[fail[x]][i]].push_back(ch[x][i]);
}
else ch[x][i] = ch[fail[x]][i];
}
}
}
void get_ans() {
int n = strlen(s) , now = 0;
for(int i = 0 ; i < n ; ++ i) {
if(s[i] == 'B') {add(l[now] , -1); now = fa[now];}
else if(s[i] == 'P') {
for(auto j = q[go[now]].begin() ; j != q[go[now]].end() ; ++ j)
ans[j -> second] = query(r[id[j -> first]]) - query(l[id[j -> first]] - 1);
}
else {
now = ch[now][s[i] - 'a'];
add(l[now] , 1);
}
}
}
}ac;
int read() {
int ans = 0 , flag = 1;
char ch = getchar();
while(ch > '9' || ch < '0') {if(ch == '-') flag = -1; ch = getchar();}
while(ch >= '0' && ch <= '9') {ans = ans * 10 + ch - '0'; ch = getchar();}
return ans * flag;
}
int lowbit(int x) {return x & (-x);}
void add(int x , int y) {if(x) for(int i = x ; i <= tot ; i += lowbit(i)) t[i] += y;}
int query(int x) {int ans = 0; if(x) for(int i = x ; i ; i -= lowbit(i)) ans += t[i]; return ans;}
void dfs(int x) {
l[x] = ++ tot;
for(auto i = e[x].begin() ; i != e[x].end() ; ++ i) dfs(*i);
r[x] = tot;
}
void init() {
scanf("%s" , s);
m = read();
for(int i = 0 ; i < m ; ++ i) {
int x = read() , y = read();
q[y].push_back({x , i});
}
}
void work() {
ac.build();
dfs(0);
}
void print() {for(int i = 0 ; i < m ; ++ i) printf("%d\n" , ans[i]);}
int main() {
init();
work();
ac.get_ans();
print();
return 0;
}