题意:
给你一个字符串集合,包含n个字符串,然后q次查询,每次查询第x个字符串和第y个字符串的LCP(最长公共前缀)在整个集合中出现的次数。
分析:
我们把整个集合串起来,各个字符串连接处加特殊符号,然后求整个串的后缀数组,然后我们用st表维护heigh数组区间最小值。对于每一次查询,先求出第x个字符串和第y个字符串的LCP长度len,然后任选x串或者y串,取该串起始位置在整个连接串的位置p,记e=rk[p];然后只需要分别从e左边和右边找heigh>=len的最左位置和最右位置。这里是可以两个二分优化的,分别二分求e左边位置和右边位置,然后答案就是这个区间长度+1。
注意:
1、所求区间可能只包含<p或者大于p的位置,二分时要注意l和r的初始值。
2、x和y可能相等
3、看代码注释处。
Code:
#include<bits/stdc++.h>
#include<stdio.h>
#include<string.h>
#include<set>
#include<vector>
using namespace std;
#define ll long long
typedef unsigned long long ull;
const int Max = 1e6+20;
char s[Max],c[Max];
vector<char>Q[Max];
int cnt[Max];
int o[Max];//标记各串起始位置在连接串中的位置
ll h[Max];
int q[Max];
int sa[Max];
int id[Max],rk[Max<<1],odrk[Max<<1],px[Max];
int M;
int st[Max][21];
int logn[Max];
void pre(int n)
{
logn[1]=0;
logn[2]=1;
for(int i=3;i<=n;++i) logn[i]=logn[i/2]+1;
for(int i=1;i<=n;++i) st[i][0]=h[i];
for(int j=1;j<=21;++j)
{
for(int i=1;i+(1<<j)-1<=n;++i)
{
st[i][j]=min(st[i][j-1],st[i+(1<<(j-1))][j-1]);
}
}
}
bool cmp(int x,int y,int w)
{
return odrk[x]==odrk[y]&&odrk[x+w]==odrk[y+w];
}
void SA(int n)
{
memset(cnt,0,sizeof cnt);
int m=300;
for(int i=1; i<=n; i++) ++cnt[rk[i]=s[i]];
for(int i=1; i<=m; i++)cnt[i]+=cnt[i-1];
for(int i=n; i>=1; i--) sa[cnt[rk[i]]--]=i;
int w;
int p,i;
for(w=1;; w<<=1,m=p)
{
for(p=0,i=n; i>n-w; --i) id[++p]=i;
for(i=1; i<=n; ++i)
if(sa[i]>w) id[++p]=sa[i]-w;
memset(cnt,0,sizeof cnt);
for(i=1; i<=n; ++i) ++cnt[px[i]=rk[id[i]]];
for(i=1; i<=m; ++i) cnt[i]+=cnt[i-1];
for(i=n; i>=1; i--) sa[cnt[px[i]]--]=id[i];
memcpy(odrk,rk,sizeof rk);
for(p=0,i=1; i<=n; ++i)
{
rk[sa[i]]=cmp(sa[i],sa[i-1],w)?p:++p;
}
if(p==n)
{
for(i=1; i<=n; ++i)
{
sa[rk[i]]=i;
}
break;
}
}
int num=0;
for(int i=1; i<=n; i++)
{
int l=rk[i]-1;
if(num) --num;
while(s[i+num]==s[sa[l]+num]) ++num;
h[rk[i]]=num;
}
}
int main()
{
int N;
cin>>N;
int n = 0;
s[++n]='.';
for(int i=1;i<=N;++i)
{
scanf("%s",c+1);
int l = strlen(c+1);
o[i]=(n+1);
for(int j =1;j<=l;++j)
{
s[++n]=c[j];
Q[i].push_back(c[j]);
}
s[++n]='.';
}
SA(n);
pre(n);
int t;
scanf("%d",&t);
int x,y;
int len;
int ans1,ans2;
while(t--)
{
ans1=ans2=-1;
len = 0;
scanf("%d%d",&x,&y);
int e = min(rk[o[x]],rk[o[y]]),E=max(rk[o[x]],rk[o[y]]);
int p = logn[E-e];
len = min(st[e+1][p],st[E-(1<<p)+1][p]);
len = min(min((int)Q[x].size(),(int)Q[y].size()),len);
//这里注意,虽然得到了LCP但是因为所有字符串中间加的字符都是一样的可能会出现LCP大于长度的情况。
if(x==y) len = Q[x].size();
if(len == 0) {cout<<0<<endl;continue;}
int l,r;
if(h[e]>=len)
{
l=1,r=e;
while(l<=r)
{
int mid = (l+r)>>1;
int p = logn[e-mid+1];
if(min(st[mid][p],st[e-(1<<p)+1][p])>=len)
{
r = mid - 1;
ans1 = mid;
}
else l = mid+1;
}
}
else ans1 = e+1;
if(h[e+1]>=len)
{
l = e+1,r=n;
while(l<=r)
{
ll mid = (l+r)>>1;
int p = logn[mid-e];
if(min(st[e+1][p],st[mid-(1<<p)+1][p])>=len)
{
l = mid + 1;
ans2 = mid;
}
else r = mid - 1;
}
}
else ans2 = e;
printf("%d\n",ans2-ans1+2);
}
return 0;
}