题目描述
30% 字符串总长T≤5000
60% T≤200000
100% T≤500000
字符串只有字符集是{a,b}
分析
首先30暴力:直接跑kmp即可。
60分:考虑朴素根号算法,由于我比较辣鸡,比赛的时候没打玄学加成的算法,所以我的根号只有60。扯远了···不讲了好垃圾。
100分:
匹配嘛···先搞个SAM出来再说。
先来考虑如何解决询问。在最原始的题目里面,是查询[l=1,r=n]的串有没有Pi,这时我们可以把原串们弄成trie,再构建SAM,然后直接用Pi在SAM上沿正常边跑,就能找到包含Pi的SAM的状态。
那么进一步是要知道有多少个原串包含它。那我们对于每一个状态x都要维护col[x],pd[x][1~n],即状态x能识别的字符串,是多少个原串的子串,分别是哪些。
设trie上一个点y,它对应的原串集合为A={S[a1],S[a2]···},它对应的SAM的状态为last[y],则易知last[y]的原始状态:col[]=|A|,pd[][a1]=1,pd[][a2]=1···pd[][ak]=1。
由SAM的性质可知:在parent树中,状态x的R集合是以他为根的子树中所有点的R集合的并集。同理col和pd也是一样的。
那么我们pd用线段树维护,就可以处理区间查询;把询问离线挂在SAM的状态上,最后在parent树中从下到上线段树合并,碰到询问查询即可。
代码
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
using namespace std;
#define fo(i,j,k) for(i=j;i<=k;i++)
const int N=500005;
int first[N*2],b[N*2],next[N*2],t1,fi1[N*2],n1[N*2],b1[N*2],fi2[N*2],bl[N*2],br[N*2],id[N*2],n2[N*2],t2,t3;
int pos,p,y,n,m,i,j,l,r,len,a[N],kan;
int sam[N*2][2],parent[N*2],mx[N*2],ts,tr[N*2][2],last[N*2],fa[N*2],pd[N*2],en[N*2],tt,prt[N*2];
int dl[N],q1,q2,st[N*2],x;
char s[N];
struct re
{
int l,r,t;
};
struct rec
{
re tr[N*10];
int size;
void insert(int x,int l,int r,int pos)
{
int m=(l+r)/2;
if (l==r)
{
tr[x].t=1;
return;
}
if (m>=pos)
{
if (!tr[x].l)
tr[x].l=++size;
insert(tr[x].l,l,m,pos);
}else
{
if (!tr[x].r)
tr[x].r=++size;
insert(tr[x].r,m+1,r,pos);
}
tr[x].t=tr[tr[x].l].t+tr[tr[x].r].t;
}
void comb(int x1,int x2,int l,int r)
{
if (l==r)
{
//printf("fail");
return;
}
int m=(l+r)/2;
if (!tr[x1].l||!tr[x2].l)
tr[x1].l=max(tr[x1].l,tr[x2].l);
else
comb(tr[x1].l,tr[x2].l,l,m);
if (!tr[x1].r||!tr[x2].r)
tr[x1].r=max(tr[x1].r,tr[x2].r);
else
comb(tr[x1].r,tr[x2].r,m+1,r);
tr[x1].t=tr[tr[x1].l].t+tr[tr[x1].r].t;
}
int get(int x,int l,int r,int i,int j)
{
int m=(l+r)/2;
if (l==i&&r==j)
return tr[x].t;
if (m>=j)
return get(tr[x].l,l,m,i,j);else
if (m<i)
return get(tr[x].r,m+1,r,i,j);
else
return get(tr[x].l,l,m,i,m)+get(tr[x].r,m+1,r,m+1,j);
}
}seg;
void cr1(int x,int y)
{
t2++;
b1[t2]=y;
n1[t2]=fi1[x];
fi1[x]=t2;
}
void cr2(int x,int y,int z,int ide)
{
t3++;
bl[t3]=y;
br[t3]=z;
id[t3]=ide;
n2[t3]=fi2[x];
fi2[x]=t3;
}
void cr(int x,int y)
{
t1++;
b[t1]=y;
next[t1]=first[x];
first[x]=t1;
}
void tradd(int x)
{
cr(x,i);
fo(l,1,len)
{
if (!tr[x][a[l]])
{
tr[x][a[l]]=++tt;
fa[tt]=x;
}
x=tr[x][a[l]];
cr(x,i);
}
en[i]=x;
}
int build(int last,int x)
{
int p=++ts,q=last,nq,np;
mx[p]=mx[q]+1;
for(;q!=-1&&(!sam[q][x]);q=parent[q]) sam[q][x]=p;
if (q==-1) parent[p]=0;else
{
nq=sam[q][x];
if (mx[nq]==mx[q]+1) parent[p]=nq;else
{
np=++ts;
mx[np]=mx[q]+1;
parent[np]=parent[nq];
parent[p]=parent[nq]=np;
sam[np][0]=sam[nq][0];
sam[np][1]=sam[nq][1];
for(;q!=-1&&sam[q][x]==nq;q=parent[q]) sam[q][x]=np;
}
}
return p;
}
void bfs(int x)
{
q1=0;
q2=1;
dl[1]=x;
last[x]=0;
while (q1<q2)
{
q1++;
x=dl[q1];
st[last[x]]=++seg.size;
for(int p=first[x];p;p=next[p])
seg.insert(st[last[x]],1,n,b[p]);
fo(j,0,1)
if (tr[x][j])
{
dl[++q2]=tr[x][j];
last[tr[x][j]]=build(last[x],j);
}
}
}
int find(int x)
{
fo(j,1,len)
{
x=sam[x][a[j]];
if (!x) return 0;
}
return x;
}
void bf(int x)
{
q1=0;q2=1;
dl[1]=x;
while (q1<q2)
{
x=dl[++q1];
for(int p=fi1[x];p;p=n1[p])
dl[++q2]=b1[p];
}
}
int main()
{
parent[0]=-1;
scanf("%d %d\n",&n,&m);
fo(i,1,n)
{
scanf("%s\n",s+1);
len=strlen(s+1);
fo(j,1,len) a[j]=s[j]-'a';
kan+=len;
tradd(0);
fo(j,1,len) s[j]=0;
}
bfs(0);
fo(i,1,ts) cr1(parent[i],i);
bf(0);
fo(i,1,m)
{
scanf("%d %d %s\n",&l,&r,s+1);
len=strlen(s+1);
fo(j,1,len) a[j]=s[j]-'a';
pos=find(0);
if (!pos) prt[i]=0;
else
{
cr2(pos,l,r,i);
fo(j,1,len) s[j]=0;
}
}
for(;q2;q2--)
{
x=dl[q2];
if (!st[x]) st[x]=++seg.size;
for(p=fi1[x];p;p=n1[p])
seg.comb(st[x],st[b1[p]],1,n);
for(p=fi2[x];p;p=n2[p])
prt[id[p]]=seg.get(st[x],1,n,bl[p],br[p]);
}
fo(i,1,m) printf("%d\n",prt[i]);
}