题意
NOIP编号为ZJ-267的小Z在NOIP中AK啦!
小Z打算去冲击省选,于是开始学习trie。
有一天,他得到了N个字符串。
他先建立一个根节点,对于每一个字符串,他都从根节点开始一点点插入。
小Z不满足于此。他的大脑里盘旋着M个问题:
如果给定一个二元组(s,t)(s,t都是trie中的节点且s是t的祖先),
存在多少个二元组(x,y)(x,y都是trie中的节点且x是y的祖先),
满足s~t路径上的字符串和x~y路径上的字符串完全一样?
注意s可以等于t,x也可以等于y。
这里为了方便,读入仅仅是N个字符串,请自己建立出trie。
同时每一组询问的格式为(pi,xi,yi),
第pi个串的第xi个字符在trie中的位置即为s,
第pi个串的第yi个字符在trie中的位置即为t。
第一行为一个整数N。(N<=100000)
接下来N行,每行一个字符串。(保证字符串总长<=1000000)
接下来一行一个整数M。(M<=100000)
接下来M行,每行三个数pi,xi,yi意义如上。
(1<=pi<=N,1<=xi<=yi<=|Spi|)
分析
一开始想到了一个后缀数组+二分+主席树的做法,看了题解发现可以用后缀自动机来做。
我们可以在trie上建sam,记录每个位置在sam上对应的节点。对于一次询问,我们可以从深度较大的位置对应的节点开始,在parents树上往上跳,直到一个节点满足其区间包含询问字符串的长度。跳的过程可以用倍增来实现。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<vector>
using namespace std;
const int N=100005;
int n,m,ch[N*20][26],size[N*20],fa[N*20],bz[N*20][21],mx[N*20],b[N*20],c[N*20],trie[N*10][26],cnt,tot,ls[N*10];
vector<int> vec[N];
char str[N*10];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void ins(int &last,int x)
{
if (ch[last][x])
{
int p=last,np=ch[p][x];
if (mx[np]==mx[p]+1) last=np;
else
{
int q=++cnt;mx[q]=mx[p]+1;
memcpy(ch[q],ch[np],sizeof(ch[q]));
fa[q]=fa[np];
fa[np]=last=q;
for (;ch[p][x]==np;p=fa[p]) ch[p][x]=q;
}
size[last]++;
return;
}
int p,q,np,nq;
p=last;last=np=++cnt;mx[np]=mx[p]+1;size[np]=1;
for (;p&&!ch[p][x];p=fa[p]) ch[p][x]=np;
if (!p) fa[np]=1;
else
{
q=ch[p][x];
if (mx[q]==mx[p]+1) fa[np]=q;
else
{
nq=++cnt;mx[nq]=mx[p]+1;
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
memcpy(ch[nq],ch[q],sizeof(ch[nq]));
for (;ch[p][x]==q;p=fa[p]) ch[p][x]=nq;
}
}
}
void prework()
{
for (int i=1;i<=cnt;i++) b[mx[i]]++;
for (int i=1;i<=cnt;i++) b[i]+=b[i-1];
for (int i=cnt;i>=1;i--) c[b[mx[i]]--]=i;
for (int i=cnt;i>=1;i--) size[fa[c[i]]]+=size[c[i]];
for (int i=1;i<=cnt;i++)
{
int x=c[i];
bz[x][0]=fa[x];
for (int j=1;j<=20;j++) bz[x][j]=bz[bz[x][j-1]][j-1];
}
}
int main()
{
n=read();ls[0]=cnt=1;
for (int i=1;i<=n;i++)
{
scanf("%s",str);
int len=strlen(str);
int now=0;
for (int j=0;j<len;j++)
{
int x=str[j]-'a';
if (trie[now][x]) now=trie[now][x];
else trie[now][x]=++tot,ls[tot]=ls[now],ins(ls[tot],x),now=tot;
vec[i].push_back(ls[now]);
}
}
prework();
m=read();
while (m--)
{
int p=read(),x=read(),y=read();
int t=vec[p][y-1],len=y-x+1;
for (int i=20;i>=0;i--) if (mx[bz[t][i]]>=len) t=bz[t][i];
printf("%d\n",size[t]);
}
return 0;
}