题目描述
传送门
题目大意:给出一个长度为n的字符串,每次询问[a,b]中的子串与[c,d]的最长公共前缀的长度的最大值
题解1:后缀数组
对字符串建立后缀数组。我们知道两个后缀的最长公共后缀等于区间[rank[i]+1,rank[j]]的height的最小值,那么因为是取min,所以区间的长度越长答案肯定不可能更优。
建立主席树,将i插入到rank[i]的位置,然后在主席树中维护每个区间的最靠左/最靠右的位置。
所以我们在[0,min(d-c+1,b-a+1)]中二分答案,可以确定一个起点的区间,然后查询主席树中这段区间[1,rank[c]]中最靠右的位置pos1,[rank[c],n]中最靠左的位置pos2,然后用st表维护区间height的最大值,查询[pos1+1,rank[c]],[rank[c]+1,pos2]的区间最小值,如果两个取max的结果大于当前二分的答案,那么就将l=mid+1,继续二分。
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 100003
using namespace std;
int sa[N],height[N],st[20][N],xx[N],yy[N],*x,*y;
int n,m,a[N],v[N],rank[N],p,L[N],sz,root[N];
char s[N];
struct data{
int ls,rs,sum,px,py;
}tr[N*20];
int cmp(int i,int j,int l)
{
return y[i]==y[j]&&(i+l>n?-1:y[i+l])==(j+l>n?-1:y[j+l]);
}
void get_sa()
{
int m1=30; x=xx; y=yy;
for (int i=1;i<=n;i++) v[x[i]=a[i]]++;
for (int i=1;i<=m1;i++) v[i]+=v[i-1];
for (int i=n;i>=1;i--) sa[v[x[i]]--]=i;
for (int k=1;k<=n;k<<=1) {
p=0;
for (int i=n-k+1;i<=n;i++) y[++p]=i;
for (int i=1;i<=n;i++)
if (sa[i]>k) y[++p]=sa[i]-k;
for (int i=1;i<=m1;i++) v[i]=0;
for (int i=1;i<=n;i++) v[x[y[i]]]++;
for (int i=1;i<=m1;i++) v[i]+=v[i-1];
for (int i=n;i>=1;i--) sa[v[x[y[i]]]--]=y[i];
swap(x,y); p=2; x[sa[1]]=1;
for (int i=2;i<=n;i++)
x[sa[i]]=(cmp(sa[i],sa[i-1],k)?p-1:p++);
if (p>n) break;
m1=p+1;
}
for (int i=1;i<=n;i++) rank[sa[i]]=i;
p=0;
for (int i=1;i<=n;i++) {
if (rank[i]==1) continue;
int j=sa[rank[i]-1];
while (j+p<=n&&i+p<=n&&a[j+p]==a[i+p]) p++;
height[rank[i]]=p;
p=max(0,p-1);
}
for (int i=1;i<=n;i++) st[0][i]=height[i];
for (int i=1;i<=17;i++)
for (int j=1;j<=n;j++)
if (j+(1<<i)-1<=n)
st[i][j]=min(st[i-1][j],st[i-1][j+(1<<(i-1))]);
int j=0;
for (int i=1;i<=n;i++) {
if (1<<(j+1)<=i) j++;
L[i]=j;
}
}
void insert(int &i,int j,int l,int r,int x)
{
i=++sz; tr[i]=tr[j]; tr[i].sum++;
if (!tr[i].px) tr[i].px=x;
else tr[i].px=min(tr[i].px,x);
if (!tr[i].py) tr[i].py=x;
else tr[i].py=max(tr[i].py,x);
if (l==r) return;
int mid=(l+r)/2;
if (x<=mid) insert(tr[i].ls,tr[j].ls,l,mid,x);
else insert(tr[i].rs,tr[j].rs,mid+1,r,x);
}
int query(int i,int j,int l,int r,int ll,int rr)
{
if (ll>rr) return 0;
if(l==r) {
if (tr[i].sum==0&&tr[j].sum) return l;
else return 0;
}
if (ll<=l&&r<=rr) {
if(tr[i].sum==0&&tr[j].sum) return tr[j].py;
}
int mid=(l+r)/2; int ans=0;
if (rr>mid&&tr[tr[j].rs].sum>tr[tr[i].rs].sum) ans=query(tr[i].rs,tr[j].rs,mid+1,r,ll,rr);
if (ll<=mid&&!ans) ans=max(ans,query(tr[i].ls,tr[j].ls,l,mid,ll,rr));
return ans;
}
int query1(int i,int j,int l,int r,int ll,int rr)
{
if (ll>rr) return N+1;
if(l==r) {
if (tr[i].sum==0&&tr[j].sum) return l;
else return N+1;
}
if (ll<=l&&r<=rr) {
if (tr[i].sum==0&&tr[j].sum) return tr[j].px;
}
int mid=(l+r)/2; int ans=N+1;
if (ll<=mid&&tr[tr[j].ls].sum>tr[tr[i].ls].sum) ans=query1(tr[i].ls,tr[j].ls,l,mid,ll,rr);
if (rr>mid&&ans==N+1) ans=min(ans,query1(tr[i].rs,tr[j].rs,mid+1,r,ll,rr));
return ans;
}
int calc(int x,int y)
{
int k=L[y-x];
return min(st[k][x],st[k][y-(1<<k)+1]);
}
bool check(int x,int a,int b,int c)
{
if (x==0) return true;
b=b-x+1; int ans=0; int val=rank[c];
if (b<a) return false;
int pos1=query(root[a-1],root[b],1,n,1,val);
int pos2=query1(root[a-1],root[b],1,n,val,n);
if (pos1!=0)
if (pos1+1<=val)ans=max(ans,calc(pos1+1,val));
else ans=max(ans,n-sa[pos1]+1);
if (pos2!=N+1)
if (val+1<=pos2) ans=max(ans,calc(val+1,pos2));
else ans=max(ans,n-sa[pos2]+1);
return ans>=x;
}
int main()
{
freopen("a.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d",&n,&m);
scanf("%s",s+1);
for (int i=1;i<=n;i++) a[i]=s[i]-'a'+1;
get_sa();
for (int i=1;i<=n;i++) insert(root[i],root[i-1],1,n,rank[i]);
for (int i=1;i<=m;i++) {
int a,b,c,d;
scanf("%d%d%d%d",&a,&b,&c,&d);
if (a==c&&b==d) {
printf("%d\n",b-a+1);
continue;
}
int l=0; int r=min(b-a+1,d-c+1); int ans=0;
while (l<=r) {
int mid=(l+r)/2;
if(check(mid,a,b,c)) l=mid+1,ans=max(ans,mid);
else r=mid-1;
}
printf("%d\n",ans);
}
}
题解2:后缀自动机
将字符串翻转,那么求最长公共前缀,就变成了求最长公共后缀。
两个点的最长公共后缀是两个点lca的len.
我们对于parent树中的每个节点,都维护他的子树中出现了字符串中的哪些节点。这个可用线段树合并。
每次找到d在树中对应的位置,二分答案,找到d的祖先中len>=x的最深的节点,判断该节点中是否出现了[a+x-1,b]。
代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define N 200003
using namespace std;
struct data{
int sum,ls,rs;
}tr[N*30];
int n,m,point[N],nxt[N],v[N],p,q,np,nq,mp[N],pos[N],cnt,last;
int root[N],rt,ch[N][27],fa[N],l[N],tot,sz,f[N][19],deep[N],mi[20];
char s[N];
void extend(int x)
{
int c=s[x]-'a'+1;
p=last; np=++cnt; last=np; mp[np]=x; pos[x]=np;
l[np]=l[p]+1;
for (;p&&!ch[p][c];p=fa[p]) ch[p][c]=np;
if (!p) fa[np]=rt;
else {
q=ch[p][c];
if (l[p]+1==l[q]) fa[np]=q;
else {
nq=++cnt; l[nq]=l[p]+1;
memcpy(ch[nq],ch[q],sizeof(ch[q]));
fa[nq]=fa[q];
fa[q]=fa[np]=nq;
for (;ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
}
}
}
void add(int x,int y)
{
tot++; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
//cout<<x<<" "<<y<<" "<<mp[y]<<endl;
}
void insert(int &i,int l,int r,int x)
{
i=++sz; tr[i].sum++;
if (l==r) return;
int mid=(l+r)/2;
if (x<=mid) insert(tr[i].ls,l,mid,x);
else insert(tr[i].rs,mid+1,r,x);
}
void update(int x)
{
int l=tr[x].ls; int r=tr[x].rs;
tr[x].sum=tr[l].sum+tr[r].sum;
}
int merge(int x,int y)
{
if (!x) return y;
if (!y) return x;
int i=++sz;
tr[i].ls=merge(tr[x].ls,tr[y].ls);
tr[i].rs=merge(tr[x].rs,tr[y].rs);
update(i);
return i;
}
void dfs(int x)
{
for (int i=1;i<=17;i++) {
if (deep[x]-mi[i]<0) break;
f[x][i]=f[f[x][i-1]][i-1];
}
for (int i=point[x];i;i=nxt[i]) {
f[v[i]][0]=x; deep[v[i]]=deep[x]+1;
dfs(v[i]);
root[x]=merge(root[x],root[v[i]]);
}
}
int find(int i,int l,int r,int ll,int rr)
{
if (ll<=l&&r<=rr) return tr[i].sum;
int mid=(l+r)/2;
int ans=0;
if (ll<=mid) ans+=find(tr[i].ls,l,mid,ll,rr);
if (rr>mid) ans+=find(tr[i].rs,mid+1,r,ll,rr);
return ans;
}
bool check(int k,int a,int b,int x)
{
if (k==0) return true;
if (a>b) return false;
for (int i=17;i>=0;i--)
if(l[f[x][i]]>=k)
x=f[x][i];
return find(root[x],1,n,a,b);
}
int main()
{
freopen("a.in","r",stdin);
freopen("my.out","w",stdout);
scanf("%d%d",&n,&m); mi[0]=1;
for (int i=1;i<=18;i++) mi[i]=mi[i-1]*2;
scanf("%s",s+1); reverse(s+1,s+n+1);
last=rt=++cnt;
for (int i=1;i<=n;i++) extend(i);
for (int i=1;i<=cnt;i++) add(fa[i],i);
for (int i=1;i<=cnt;i++)
if (mp[i]) insert(root[i],1,n,mp[i]);
deep[1]=1; dfs(1);
for (int i=1;i<=m;i++) {
int a,b,c,d;
scanf("%d%d%d%d",&a,&b,&c,&d);
swap(a,b); swap(c,d);
a=n-a+1; b=n-b+1; c=n-c+1; d=n-d+1;
int l=0; int r=min(d-c+1,b-a+1); int ans=0;
while (l<=r) {
int mid=(l+r)/2;
if (check(mid,a+mid-1,b,pos[d])) ans=max(ans,mid),l=mid+1;
else r=mid-1;
}
printf("%d\n",ans);
}
}