比方说下面的这个Trie树,蓝色的箭头指向的就是其fail指针指向的节点。
我们不妨转换一下思路,对于每个x串,只有能通过fail指针指向它的末尾节点的y串节点才能计数。
那么我们不妨把fail指针反向,构建一棵fail树。
由于在一颗树中,一个节点及其子树在DFS序中是连续的一段,那么我们可以用一个树状数组来维护x串末尾节点及其子树上有多少个属于y串的节点。
那么我们可以得到一个离线算法:对fail树遍历一遍,得到一个DFS序,再维护一个树状数组,对原Trie树进行遍历,每访问一个节点,就修改树状数组,对树状数组中该节点的DFS序起点的位置加上1。每往回走一步,就减去1。如果访问到了一个y字串的末尾节点,枚举询问中每个y串对应的x串,查询树状数组中x串末尾节点从DFS序中的起始位置到结束位置的和,并记录答案。这样,我们就得到了一个时间复杂度为O(N+MlogN)的优美的算法。
代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
char s[maxn];
int head[maxn],h[maxn],cnt,cntt,root,m;
struct edge
{
int to,nxt;
}e[maxn<<1],ef[maxn<<1];
void add(int x,int y)
{
e[++cnt].to=y;
e[cnt].nxt=head[x];
head[x]=cnt;
}
void addd(int x,int y)
{
ef[++cntt].to=y;
ef[cntt].nxt=h[x];
h[x]=cntt;
}
int strnode[maxn],tot,str;
struct tree
{
int son[26],fail,f;
}tr[maxn];
int he[maxn],dfs_time,ta[maxn];
void dfs(int x)
{
he[x]=++dfs_time;
for(int i=h[x];i;i=ef[i].nxt)
dfs(ef[i].to);
ta[x]=dfs_time;
}
int ans[maxn],c[maxn];
int lowbit(int x)
{
return x&(-x);
}
void update(int x,int y)
{
while(x<=dfs_time)
{
c[x]+=y;
x+=lowbit(x);
}
}
int query(int x)
{
int res=0;
while(x)
{
res+=c[x];
x-=lowbit(x);
}
return res;
}
int main()
{
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
scanf("%s",s);
int len=strlen(s),now=0,x,y;
scanf("%d",&m);
for(int i=0;i<m;i++)
{
scanf("%d%d",&x,&y);
add(y,x);
}
for(int i=0;i<len;i++)
{
if(s[i]=='P') strnode[++str]=now;
else if(s[i]=='B') now=tr[now].f;
else
{
int dig=s[i]-'a';
if(tr[now].son[dig]) now=tr[now].son[dig];
else
{
int cur=now;
tr[now=tr[now].son[dig]=++tot].f=cur;
}
}
}
queue <int> q;
for(int i=0;i<26;i++)
if(tr[0].son[i])
{
q.push(tr[0].son[i]);
addd(0,tr[0].son[i]);
tr[tr[0].son[i]].fail=0;
}
while(!q.empty())
{
int u=q.front();
q.pop();
for(int i=0;i<26;i++)
{
if(tr[u].son[i])
{
q.push(tr[u].son[i]);
for(now=tr[u].fail;now!=root && !tr[now].son[i];now=tr[now].fail);
tr[tr[u].son[i]].fail=tr[now].son[i]?tr[now].son[i]:root;
addd(tr[tr[u].son[i]].fail,tr[u].son[i]);
}
}
}
dfs(root);
now=root; str=0;
for(int i=0;i<len;i++)
{
if(s[i]=='B')
{
update(he[now],-1);
now=tr[now].f;
}
else if(s[i]!='P')
{
now=tr[now].son[s[i]-'a'];
update(he[now],1);
}
else
{
for(int x=head[++str];x;x=e[x].nxt)
ans[x]=query(ta[strnode[e[x].to]])-query(he[strnode[e[x].to]]-1);
}
}
for(int i=1;i<=m;i++) printf("%d\n",ans[i]);
return 0;
}