Misha has a tree with characters written on the vertices. He can choose two vertices s and t of this tree and write down characters of vertices lying on a path from s to t. We'll say that such string corresponds to pair (s, t).
Misha has m queries of type: you are given 4 vertices a, b, c, d; you need to find the largest common prefix of the strings that correspond to pairs (a, b) and (c, d). Your task is to help him.
The first line contains integer n (1 ≤ n ≤ 300 000) — the number of vertices in the tree.
Next follows a line consisting of n small English letters. The i-th character of the string corresponds to the character written on the i-th vertex.
Next n - 1 lines contain information about edges. An edge is defined by a pair of integers u, v (1 ≤ u, v ≤ n, u ≠ v), separated by spaces.
The next line contains integer m (1 ≤ m ≤ 1 000 000) — the number of queries.
Next m lines contain information about queries. A query is defined by four integers a, b, c, d (1 ≤ a, b, c, d ≤ n), separated by spaces.
For each query print the length of the largest common prefix on a separate line.
Sample input:
6
bbbabb
2 1
3 2
4 3
5 2
6 5
6
2 5 3 1
1 5 2 3
5 6 5 6
6 3 4 1
6 2 3 4
2 2 4 5
Sample output:
2
2
2
0
1
0
大神的代码:
// Capacity used for every fixed-size array in this solution.
const int N=600005;
// Adjacency lists stored as linked edge records: head[u] is the index of the
// edge most recently added for u, node[i] is the endpoint of edge i, next[i]
// is the following edge in the same list (traversals stop at -1), and e is the
// number of edge records written so far.
int next[N],node[N],head[N],e;
// Pushes the directed edge u -> v onto the front of u's edge list.
void add(int u,int v)
{
    const int id=e++;      // slot claimed by the new edge record
    node[id]=v;
    next[id]=head[u];      // old first edge becomes the successor
    head[u]=id;
}
// Number of vertices (read in main).
int n;
// Vertex labels; main reads them starting at s[1], so vertex i carries s[i].
char s[N];
// Filled by dfs: dep[u] is the parent's depth plus one, sonNum[u] is the size
// of u's subtree including u, fa[u] is the vertex dfs was entered from.
int dep[N],sonNum[N],fa[N];
// First traversal: stores parent, depth and subtree size for u and everything
// below it. pre is the vertex we arrived from and is skipped among neighbours.
void dfs(int u,int pre)
{
    fa[u]=pre;
    dep[u]=dep[pre]+1;
    int cnt=1;                         // u itself
    for(int ei=head[u];ei!=-1;ei=next[ei])
    {
        const int child=node[ei];
        if(child==pre) continue;
        dfs(child,u);
        cnt+=sonNum[child];
    }
    sonNum[u]=cnt;
}
// Filled by DFS: root[u] is the chain identifier passed down to u (the vertex
// at which u's chain was started), and end[rt] is the last vertex DFS visited
// on the chain started at rt, i.e. its deepest vertex.
int root[N],end[N];
// Second traversal: assigns u to the chain started at rt, continues that chain
// through the child with the largest subtree, and starts a fresh chain at every
// other child. Children are recognised by having a greater depth than u.
void DFS(int u,int rt)
{
    root[u]=rt;
    end[rt]=u;                         // overwritten until the chain bottoms out

    int heavy=0;                       // 0 doubles as "none"; sonNum[0] is the baseline
    for(int ei=head[u];ei!=-1;ei=next[ei])
    {
        const int child=node[ei];
        if(dep[child]<=dep[u]) continue;
        if(sonNum[child]>sonNum[heavy]) heavy=child;
    }
    if(heavy==0) return;

    DFS(heavy,rt);
    for(int ei=head[u];ei!=-1;ei=next[ei])
    {
        const int child=node[ei];
        if(child==heavy||dep[child]<=dep[u]) continue;
        DFS(child,child);
    }
}
// Integer text handed to the suffix array; main fills it with letter codes
// (s[k]-'a'+1) chain by chain.
int S[N];
// Index of the last slot of S written so far; main increments it once at the
// end so that it becomes the text length.
int sNum;
// Set by main for chain-start vertices only: p[i][0]..p[i][1] is the index
// range in S holding the chain written from end[i] up to i, and
// p[i][2]..p[i][3] is the range holding the same characters in reverse order.
int p[N][4];
// Suffix array over an integer text, plus the LCP-of-neighbours array h and a
// sparse table f over h, so that the longest common prefix of two suffixes is
// obtained as a range minimum.
struct SufArr
{
// sa: sorted suffix start positions; wa/wb: rank and scratch buffers used by
// da; wd: counting-sort buckets; rank: position of each suffix inside sa;
// h[i]: common prefix length of the suffixes at sa[i-1] and sa[i].
// The member r is not referenced by the methods below; each of them works on a
// parameter that is also named r.
int r[N],sa[N],wa[N],wb[N],wd[N],rank[N],h[N];
// Returns non-zero when positions a and b agree in r both at offset 0 and at
// offset len. The second comparison is only evaluated if the first one holds.
int cmp(int *r,int a,int b,int len)
{
return r[a]==r[b]&&r[a+len]==r[b+len];
}
// Sorts the n suffixes of r[0..n-1] into sa. m must be greater than every
// value in r (it is the number of counting-sort buckets). Each round doubles
// the compared prefix length j and stops once all ranks are distinct (p==n).
void da(int *r,int *sa,int n,int m)
{
int i,j,p,*x=wa,*y=wb,*t;
// Initial counting sort by single character; x holds the current ranks.
for(int i=0;i<m;i++) wd[i]=0;
for(int i=0;i<n;i++) wd[x[i]=r[i]]++;
for(int i=1;i<=m-1;i++) wd[i]+=wd[i-1];
for(int i=n-1;i>=0;i--) sa[--wd[x[i]]]=i;
for(j=1,p=1;p<n;j<<=1,m=p)
{
p=0;
// y lists the suffixes ordered by their second half: those with no second
// half (start >= n-j) first, then the rest in the order given by sa.
for(int i=n-j;i<=n-1;i++) y[p++]=i;
for(int i=0;i<n;i++) if(sa[i]>=j) y[p++]=sa[i]-j;
// Stable counting sort of y by first-half rank x.
for(int i=0;i<m;i++) wd[i]=0;
for(int i=0;i<n;i++) wd[x[i]]++;
for(int i=1;i<=m-1;i++) wd[i]+=wd[i-1];
for(int i=n-1;i>=0;i--) sa[--wd[x[y[i]]]]=y[i];
// Swap buffers: y now holds the old ranks, x receives the new ones.
t=x;x=y;y=t;p=1;x[sa[0]]=0;
for(int i=1;i<=n-1;i++) x[sa[i]]=cmp(y,sa[i-1],sa[i],j)?p-1:p++;
}
}
// Fills rank and h for the text r[0..n-1]. sa is expected to hold n+1 entries
// (process builds it over len+1 symbols), which is why ranks run from 1 to n.
void calHeight(int *r,int *sa,int n)
{
int i,j,k=0;
for(int i=1;i<=n;i++) rank[sa[i]]=i;
for(int i=0;i<n;i++)
{
// The match length carried over from position i-1 can drop by at most one.
if(k) k--;
j=sa[rank[i]-1];
while(i+k<n&&j+k<n&&r[i+k]==r[j+k]) k++;
h[rank[i]]=k;
}
}
// f[j][i]: minimum of h over the 2^i entries starting at j; n: text length set
// by process.
int f[N][20],n;
// Builds the sparse table f from h[1..n].
void init()
{
int i,j;
for(int i=1;i<=n;i++) f[i][0]=h[i];
for(i=1;1+(1<<i)<=n;i++) for(j=1;j+(1<<i)-1<=n;j++)
{
f[j][i]=min(f[j][i-1],f[j+(1<<(i-1))][i-1]);
}
}
// Length of the longest common prefix of the suffixes starting at text
// positions a and b. For a==b this is the whole remaining suffix, n-a.
int cal(int a,int b)
{
if(a==b) return n-a;
a=rank[a];
b=rank[b];
if(a>b) swap(a,b);
// The answer is the minimum of h over sorted positions (rank a)+1 .. rank b.
a++;
int m=floor(log(1.0*(b-a+1))/log(2.0));
return min(f[a][m],f[b-(1<<m)+1][m]);
}
// Entry point. The caller stores the text in r[0..len-1] with every value
// greater than 0 and below typeNum; r[len] is overwritten with 0 as a
// terminator, so r needs room for len+1 values.
void process(int *r,int len,int typeNum)
{
n=len;
r[len]=0;
da(r,sa,len+1,typeNum);
calHeight(r,sa,n);
init();
}
}A;
vector<pair<int,int> > init(int a,int b)
{
vector<pair<int,int> > pp,qq;
while(root[a]!=root[b])
{
if(dep[root[a]]>dep[root[b]])
{
int rt=root[a];
int y=p[rt][1];
int x=y-(dep[a]-dep[rt]);
pp.pb(MP(x,y));
a=fa[rt];
}
else
{
int rt=root[b];
int x=p[rt][2];
int y=x+(dep[b]-dep[rt]);
qq.pb(MP(x,y));
b=fa[rt];
}
}
if(dep[a]>dep[b])
{
int rt=root[a];
int e=p[rt][1];
int y=e-(dep[b]-dep[rt]);
int x=e-(dep[a]-dep[rt]);
pp.pb(MP(x,y));
}
else
{
int rt=root[a];
int e=p[rt][2];
int y=e+(dep[b]-dep[rt]);
int x=e+(dep[a]-dep[rt]);
pp.pb(MP(x,y));
}
for(int i=SZ(qq)-1;i>=0;i--)
{
pp.pb(qq[i]);
}
return pp;
}
// Longest common prefix of the strings spelled by paths a -> b and c -> d.
// Both paths are cut into ranges of S, and the ranges are consumed in lockstep
// using suffix LCP queries.
int cal(int a,int b,int c,int d)
{
    vector<pair<int,int> > segP=init(a,b);
    vector<pair<int,int> > segQ=init(c,d);
    int matched=0;
    for(int i=0,j=0;i<SZ(segP)&&j<SZ(segQ);)
    {
        const int remP=segP[i].second-segP[i].first+1;
        const int remQ=segQ[j].second-segQ[j].first+1;
        const int cap=min(remP,remQ);
        const int common=A.cal(segP[i].first,segQ[j].first);
        // A mismatch inside the shorter remaining piece ends the prefix.
        if(common<cap) return matched+common;
        matched+=cap;
        if(cap>=remP) ++i;
        else segP[i].first+=cap;
        if(cap>=remQ) ++j;
        else segQ[j].first+=cap;
    }
    return matched;
}
int main()
{
clr(head,-1);
n=myInt();
scanf("%s",s+1);
for(int i=1;i<n;i++)
{
int u=myInt();
int v=myInt();
add(u,v);
add(v,u);
}
dfs(1,0);
DFS(1,1);
sNum=-1;
for(int i=1;i<=n;i++) if(i==root[i])
{
p[i][0]=sNum+1;
for(int k=end[i];;k=fa[k])
{
S[++sNum]=s[k]-'a'+1;
if(k==i) break;
}
p[i][1]=sNum;
p[i][2]=sNum+1;
for(int k=sNum;k>=p[i][0];k--)
{
S[++sNum]=S[k];
}
p[i][3]=sNum;
}
sNum++;
A.process(S,sNum,30);
int Q=myInt();
while(Q--)
{
int a=myInt();
int b=myInt();
int c=myInt();
int d=myInt();
printf("%d\n",cal(a,b,c,d));
}
}
我对着抄了一遍。。。
#include<iostream>
#include<cstdio>
#include<string>
#include<cstring>
#include<vector>
#include<cmath>
#include<queue>
#include<stack>
#include<map>
#include<set>
#include<algorithm>
using namespace std;
// Capacity for per-vertex arrays.
const int maxn=300010;
// Vertex labels; main reads them starting at str[1].
char str[maxn];
// N: number of vertices, Q: number of queries (both read in main).
int N,Q;
// head: first edge index per vertex (-1 when empty); top: first vertex of the
// chain a vertex belongs to; end[t]: last vertex dfs2 visited on the chain that
// starts at t; num: subtree size; fa: parent; son: child with the largest
// subtree (-1 if none); w: visit order assigned by dfs2; deep: depth.
// NOTE(review): with `using namespace std;` in effect the global name `end`
// may be ambiguous with std::end on C++11 and later -- confirm with the target
// compiler.
int head[maxn],top[maxn],end[maxn],num[maxn],fa[maxn],son[maxn],w[maxn],deep[maxn];
// pos: next value handed out for w; tot: number of edge records in use.
int pos,tot;
// Text passed to the suffix array; twice maxn because main stores every chain
// block followed by a reversed copy of it.
char S[maxn*2];
// One directed edge record of the adjacency lists: v is the endpoint, next is
// the index of the following record in the same list (-1 ends the list).
// The pool holds 2*maxn records because add_edge is called once per direction.
struct node
{
int v,next;
}edge[maxn*2];
void init()
{
pos=tot=0;
memset(head,-1,sizeof(head));
memset(son,-1,sizeof(son));
}
// Prepends the directed edge u -> v to u's adjacency list.
void add_edge(int u,int v)
{
    const int id=tot++;            // record claimed for this edge
    edge[id].next=head[u];
    edge[id].v=v;
    head[u]=id;
}
// First traversal: records depth d, parent pre and subtree size for u, and
// picks as son[u] the child with the largest subtree (first one wins on ties).
void dfs1(int u,int pre,int d)
{
    fa[u]=pre;
    deep[u]=d;
    num[u]=1;
    for(int ei=head[u];ei!=-1;ei=edge[ei].next)
    {
        const int child=edge[ei].v;
        if(child!=pre)
        {
            dfs1(child,u,d+1);
            num[u]+=num[child];
            const int best=son[u];
            if(best==-1||num[best]<num[child])
                son[u]=child;
        }
    }
}
// Second traversal: puts u on the chain that starts at sp, numbers it, extends
// the chain through son[u] first and then opens a new chain at every other
// child.
void dfs2(int u,int sp)
{
    const int heavy=son[u];
    top[u]=sp;
    end[sp]=u;                     // keeps being overwritten until the chain ends
    w[u]=pos++;
    if(heavy!=-1)
    {
        dfs2(heavy,sp);
        for(int ei=head[u];ei!=-1;ei=edge[ei].next)
        {
            const int child=edge[ei].v;
            if(child==fa[u]||child==heavy)continue;
            dfs2(child,child);
        }
    }
}
struct SUFFIX
{
int n;
int d[maxn][20];
int sa[maxn],t[maxn],t2[maxn],height[maxn],rank[maxn],c[maxn];
void build_sa(char *s,int n,int m)
{
int *x=t,*y=t2;
for(int i=0;i<m;i++)c[i]=0;
for(int i=0;i<n;i++)c[x[i]=s[i]]++;
for(int i=1;i<m;i++)c[i]+=c[i-1];
for(int i=n-1;i>=0;i--)sa[--c[x[i]]]=i;
for(int k=1;k<=n;k<<=1)
{
int p=0;
for(int i=n-k;i<n;i++)y[p++]=i;
for(int i=0;i<n;i++)if(sa[i]>=k)y[p++]=sa[i]-k;
for(int i=0;i<m;i++)c[i]=0;
for(int i=0;i<n;i++)c[x[y[i]]]++;
for(int i=1;i<m;i++)c[i]+=c[i-1];
for(int i=n-1;i>=0;i--)sa[--c[x[y[i]]]]=y[i];
swap(x,y);p=1;
x[sa[0]]=0;
for(int i=1;i<n;i++)
x[sa[i]]=(y[sa[i]]==y[sa[i-1]]&&y[sa[i]+k]==y[sa[i-1]+k])?p-1:p++;
if(p>=n)break;
m=p;
}
}
void getheight(char *s,int n)
{
int k=0;
for(int i=1;i<=n;i++)rank[sa[i]]=i;
for(int i=0;i<n;i++)
{
if(k)k--;
int j=sa[rank[i]-1];
while(s[i+k]==s[j+k])k++;
height[rank[i]]=k;
}
}
void initRMQ(int n)
{
for(int i=0;i<=n;i++)d[i][0]=height[i];
for(int j=1;(1<<j)<=n;j++)
for(int i=1;i+(1<<(j-1))<=n;i++)
d[i][j]=min(d[i][j-1],d[i+(1<<(j-1))][j-1]);
}
void solve(char *s,int len)
{
s[len]=0;
n=len;
build_sa(s,len+1,30);
getheight(s,len);
initRMQ(n);
}
int LCP(int a,int b)
{
if(a==b)return n-a;
int x=rank[a],y=rank[b];
if(x>y)swap(x,y);
x++;
int k=0;
while((1<<(k+1))<(y-x+1))k++;
return min(d[x][k],d[y-(1<<k)+1][k]);
}
}A;
// Set by main for chain-start vertices i only: p[i][0] and p[i][1] are the
// first and last index in S of the block written while walking from end[i]
// towards i; p[i][2] and p[i][3] bound the reversed copy stored right after it.
int p[maxn][4];
// Splits the tree path a -> b into index ranges [first, second] of S, returned
// in path order. Parts climbed on a's side read the bottom-to-top copy of a
// chain (ending at p[.][1]); parts descended on b's side read the top-to-bottom
// copy (starting at p[.][2]).
vector<pair<int,int> >initPATH(int a,int b)
{
    vector<pair<int,int> > tmp1,tmp2;
    int f1=top[a],f2=top[b];
    while(f1!=f2)
    {
        if(deep[f1]>deep[f2])
        {
            // Consume a .. f1 upwards, then continue from f1's parent.
            int y=p[f1][1];
            int x=y-(deep[a]-deep[f1]);
            tmp1.push_back(make_pair(x,y));
            a=fa[f1],f1=top[a];
        }
        else
        {
            // Consume f2 .. b downwards, then continue from f2's parent.
            // (Must hop over b's own chain start f2, not a's chain start f1.)
            int x=p[f2][2];
            int y=x+(deep[b]-deep[f2]);
            tmp2.push_back(make_pair(x,y));
            b=fa[f2],f2=top[b];
        }
    }
    // a and b are on the same chain now (f1 is its start); choose the copy of
    // that chain which reads from a towards b.
    if(deep[a]>deep[b])
    {
        int e=p[f1][1];
        int y=e-(deep[b]-deep[f1]);
        int x=e-(deep[a]-deep[f1]);
        tmp1.push_back(make_pair(x,y));
    }
    else
    {
        int e=p[f1][2];
        int y=e+(deep[b]-deep[f1]);
        int x=e+(deep[a]-deep[f1]);
        tmp1.push_back(make_pair(x,y));
    }
    // b-side parts were gathered from b upwards; append them in reverse.
    for(int i=(int)tmp2.size()-1;i>=0;i--)
        tmp1.push_back(tmp2[i]);
    return tmp1;
}
// Longest common prefix of the strings on paths a -> b and c -> d: both paths
// are cut into ranges of S, which are then matched piece by piece with suffix
// LCP queries.
int cal(int a,int b,int c,int d)
{
    vector<pair<int,int> > segA=initPATH(a,b);
    vector<pair<int,int> > segB=initPATH(c,d);
    int total=0;
    size_t ia=0,ib=0;
    while(ia<segA.size()&&ib<segB.size())
    {
        const int remA=segA[ia].second-segA[ia].first+1;
        const int remB=segB[ib].second-segB[ib].first+1;
        const int cap=min(remA,remB);
        const int common=A.LCP(segA[ia].first,segB[ib].first);
        // Mismatch before the shorter piece is exhausted: the prefix ends here.
        if(common<cap)return total+common;
        total+=cap;
        if(cap>=remA)++ia;
        else segA[ia].first+=cap;
        if(cap>=remB)++ib;
        else segB[ib].first+=cap;
    }
    return total;
}
// For each test case: reads the tree, builds the chains, writes every chain
// into S twice (bottom-to-top, then reversed), builds the suffix structure and
// answers the LCP queries.
int main()
{
    // ==1 rather than !=EOF: a non-numeric token makes scanf return 0 forever.
    while(scanf("%d",&N)==1)
    {
        init();
        scanf("%s",str+1);
        int u,v;
        for(int i=1;i<N;i++)
        {
            scanf("%d%d",&u,&v);
            add_edge(u,v);
            add_edge(v,u);
        }
        dfs1(1,0,0);
        dfs2(1,1);
        int snum=-1;               // index of the last slot of S written so far
        for(int i=1;i<=N;i++)
        {
            if(i!=top[i])continue; // only chain-start vertices own a block of S
            p[i][0]=snum+1;
            // Walk from the chain's last vertex up to AND INCLUDING i; the
            // path decomposition expects p[i][1] to be the position of i.
            for(int k=end[i];;k=fa[k])
            {
                S[++snum]=str[k]-'a'+1;
                if(k==i)break;
            }
            p[i][1]=snum;
            p[i][2]=snum+1;
            // Reversed copy, so the chain can also be read top-to-bottom.
            for(int k=snum;k>=p[i][0];k--)
                S[++snum]=S[k];
            p[i][3]=snum;
        }
        snum++;                    // turn the last index into the text length
        A.solve(S,snum);
        scanf("%d",&Q);
        while(Q--)
        {
            int a,b,c,d;
            scanf("%d%d%d%d",&a,&b,&c,&d);
            printf("%d\n",cal(a,b,c,d));
        }
    }
    return 0;
}