Description
小Z打算去冲击省选,于是开始学习trie。
有一天,他得到了N个字符串。
他先建立一个根节点,对于每一个字符串,他都从根节点开始一点点插入。
小Z不满足于此。他的大脑里盘旋着M个问题:
如果给定一个二元组(s,t)(s,t都是trie中的节点且s是t的祖先),
存在多少个二元组(x,y)(x,y都是trie中的节点且x是y的祖先),
满足s~t路径上的字符串和x~y路径上的字符串完全一样?
注意s可以等于t,x也可以等于y。
这里为了方便,读入仅仅是N个字符串,请自己建立出trie。
同时每一组询问的格式为(pi,xi,yi),
第pi个串的第xi个字符在trie中的位置即为s,
第pi个串的第yi个字符在trie中的位置即为t。
Data Constraint
N≤100000 字符串总长≤1000000 (1≤pi≤N,1≤xi≤yi≤|Spi|)
字符均为小写字母
分析
题目就是求一个字符串在trie上的出现次数。
求子串出现次数,很容易想到用SAM!
先构出trie,然后构出SAM。对于一个询问,我们从位置t在SAM上所在的位置开始沿parent树往上跳,直到一个节点,再往上跳的Max会小于字符串长度,然后该节点的right集合大小即为答案。(该节点及其在parent树上的子节点能表示的字符串的后缀中一定有给定的字符串,而继续往上跳会把其它子串算进答案)。
注意树的深度可能很大,所以要打非递归。
代码处理询问部分写烂了,跑得较慢
C++
2625 ms
369028 KB
Accepted
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=100005,maxm=1000005,maxt=2100005,N=1e6,maxM=2000005;
int n,m,tot,cnt,st[maxn],id[maxm],pos[maxm],data[maxM],ans[maxn],t[maxt],a[maxM*2],fi[maxM],la[maxM],sum[maxM];
int Parent[maxM],Max[maxM],size[maxM],h[maxM],e[maxM],next[maxM],H[maxM],E[maxn],Next[maxn];
char S[maxm];
int e1[maxm][26],e2[maxM][26];
void add(int x,int y)
{
e[++cnt]=y; next[cnt]=h[x]; h[x]=cnt;
}
void Add(int x,int y)
{
E[++cnt]=y; Next[cnt]=H[x]; H[x]=cnt;
}
int insert(int last,int c)
{
int np=++tot,p,q,nq; sum[np]=1;
Max[np]=Max[last]+1;
for (p=last;p>=0 && !e2[p][c];p=Parent[p]) e2[p][c]=np;
if (p<0) Parent[np]=0;else
{
q=e2[p][c];
if (Max[q]==Max[p]+1) Parent[np]=q;else
{
nq=++tot;
memcpy(e2[nq],e2[q],sizeof(e2[q]));
Parent[nq]=Parent[q]; Max[nq]=Max[p]+1;
Parent[np]=Parent[q]=nq;
for (;p>=0 && e2[p][c]==q;p=Parent[p]) e2[p][c]=nq;
}
}
return np;
}
void change(int l,int r,int g,int v,int x)
{
if (l==r)
{
t[x]=v; return;
}
int mid=(l+r)/2;
if (g<=mid) change(l,mid,g,v,x*2);else change(mid+1,r,g,v,x*2+1);
t[x]=(t[x*2])?t[x*2]:t[x*2+1];
}
int get(int l,int r,int g,int x)
{
if (l==g) return t[x];
int mid=(l+r)/2;
if (mid<g) return get(mid+1,r,g,x*2+1);
int tmp=get(l,mid,g,x*2);
return (tmp)?tmp:t[x*2+1];
}
int main()
{
scanf("%d",&n);
for (int i=0;i<n;i++)
{
scanf("%s",&S);
int len=strlen(S),x,j;
st[i+1]=st[i]+len;
for (x=j=0;j<len;j++)
{
if (!e1[x][S[j]-'a']) e1[x][S[j]-'a']=++tot;;
x=e1[x][S[j]-'a'];
id[st[i]+j]=x;
}
}
Parent[0]=-1;
data[cnt=1]=tot=0;
for (int i=1;i<=cnt;i++)
{
int x=data[i];
for (int j=0;j<26;j++) if (e1[x][j])
{
data[++cnt]=e1[x][j];
pos[data[cnt]]=insert(pos[x],j);
}
}
cnt=0;
for (int i=1;i<=tot;i++) add(Parent[i],i);
data[cnt=1]=0;
for (int i=1;i<=cnt;i++)
{
int x=data[i];
for (int j=h[x];j;j=next[j]) data[++cnt]=e[j];
}
for (int i=cnt;i>1;i--)
{
size[Parent[data[i]]]+=size[data[i]]+1;
sum[Parent[data[i]]]+=sum[data[i]];
}
fi[0]=1; la[0]=size[0]*2+2;
for (int i=1;i<=cnt;i++)
{
int x=data[i],last=fi[x];
for (int j=h[x];j;j=next[j])
{
fi[e[j]]=last+1; la[e[j]]=fi[e[j]]+size[e[j]]*2+1;
a[fi[e[j]]]=a[la[e[j]]]=e[j];
last=la[e[j]];
}
}
cnt=0;
scanf("%d",&m);
while (m--)
{
int p,x,y;
scanf("%d%d%d",&p,&x,&y);
Add(pos[id[st[p-1]+y-1]],y-x+1);
}
for (int i=1;i<=la[0];i++)
if (i==fi[a[i]])
{
change(0,N,Max[a[i]],a[i],1);
for (int j=H[a[i]];j;j=Next[j]) ans[j]=sum[get(0,N,E[j],1)];
}else change(0,N,Max[a[i]],0,1);
for (int i=1;i<=cnt;i++) printf("%d\n",ans[i]);
return 0;
}