【题目链接】
【题意】
给出一个字符串 S ( ∣ S ∣ ≤ 2 × 1 0 5 ) S(|S|\le 2\times 10^5) S(∣S∣≤2×105) ,有 q ( ≤ 2 × 1 0 5 ) q(\le 2\times 10^5) q(≤2×105) 组询问,每次询问取 S S S 的长度为 x x x 的前缀为 A A A ,取 S S S 的长度为 y y y 的后缀为 B B B ,询问字符串 A + B A+B A+B 在 S S S 中出现了多少次。
【思路】
看起来像是一个匹配问题。
考虑 KMP \text{KMP} KMP 的 f a i l fail fail 指针构成的树,树上的每个结点代表一个前缀,结点的父亲就是该前缀的最长公共前后缀。也就是说,该结点的父亲(乃至祖先),都在该结点代表的前缀串的尾部出现了一次匹配。不难发现,一个长度为 x x x 的前缀在该串中出现的所有匹配位置(尾部位置),就是其 f a i l fail fail 树上结点的子树所代表的位置集合。
对于后缀,我们考虑反串的的 f a i l fail fail 树,具有同样的性质。
我们发现,如果一组前缀和后缀拼起来在原串中的每次出现,必然对应着该前缀的结尾位置 x x x 和后缀的起始位置 x + 1 x+1 x+1 。因此不妨建立映射,让正串 f a i l fail fail 树和 反串 f a i l fail fail 树上的结点一一对应,那么每一组询问实际上就是问两棵树上的两棵子树中有多少对正好对应的结点数。
这就变成了经典的二维数点问题。
把询问离线,根据 x x x 把询问挂在正串的 f a i l fail fail 树上。然后遍历正串的 f a i l fail fail 树,对反串的 f a i l fail fail 树跑 d f s dfs dfs 序后用一个树状数组维护。每遇到正串的一个结点,就把对应的反串结点加入树状数组中。对于 x x x 上的询问,在进入 x x x 结点时查一次反串树上 y y y 的子树的区间求和,从 x x x 结点回溯之前再查一次反串树上 y y y 的子树的区间求和,两次结果作差就是询问的答案了。
复杂度 O ( n + q log n ) O(n+q\log n) O(n+qlogn)
【代码】
#include <vector>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define MAXN 210000
using namespace std;
int n,m;
char s[MAXN];
int nxt[MAXN];
int tot,pre[MAXN],lin[MAXN*2],to[MAXN*2],num;
int dfn[MAXN],out[MAXN];
vector<int> vec[MAXN];
void add(int x,int y)
{
tot++;lin[tot]=pre[x];pre[x]=tot;to[tot]=y;
}
struct BIT{
int q[MAXN];
BIT(){memset(q,0,sizeof(q));}
void add(int loc,int x){for(int i=loc;i<MAXN;i+=i&-i)q[i]+=x;}
int getsum(int x){int ret=0;for(int i=x;i;i-=i&-i)ret+=q[i];return ret;}
}bit;
struct Q{
int x,y,ans;
Q(){x=y=ans=0;}
}q[MAXN];
void KMP()
{
nxt[1]=0;
for(int i=2,j=0;i<=n;i++)
{
while(j&&s[i]!=s[j+1])j=nxt[j];
if(s[i]==s[j+1])j++;
nxt[i]=j;
}
}
void dfs(int x)
{
dfn[x]=++num;
for(int i=pre[x];i;i=lin[i])
{
int v=to[i];
dfs(v);
}
out[x]=num;
}
void dfs2(int x)
{
int aim=n-x;
for(auto &i:vec[x])
{
q[i].ans-=bit.getsum(out[q[i].y])-bit.getsum(dfn[q[i].y]-1);
}
bit.add(dfn[aim],1);
for(int i=pre[x];i;i=lin[i])
{
int v=to[i];
dfs2(v);
}
for(auto &i:vec[x])
{
q[i].ans+=bit.getsum(out[q[i].y])-bit.getsum(dfn[q[i].y]-1);
}
}
void init()
{
tot=num=0;
memset(pre,0,sizeof(pre));
for(int i=0;i<MAXN;i++)vec[i].clear();
memset(bit.q,0,sizeof(bit.q));
}
int main()
{
int T;
scanf("%d",&T);
while(T--)
{
init();
scanf("%d%d",&n,&m);
scanf("%s",s+1);
for(int i=1;i<=n/2;i++)swap(s[i],s[n-i+1]);
KMP();
for(int i=1;i<=n;i++)add(nxt[i],i);
dfs(0);
for(int i=1;i<=n/2;i++)swap(s[i],s[n-i+1]);
KMP();
tot=0;memset(pre,0,sizeof(pre));
for(int i=1;i<=n;i++)add(nxt[i],i);
for(int i=1;i<=m;i++)
{
scanf("%d%d",&q[i].x,&q[i].y);
q[i].ans=0;
vec[q[i].x].push_back(i);
}
dfs2(0);
for(int i=1;i<=m;i++)printf("%d\n",q[i].ans);
}
return 0;
}