Problem
给定一颗N个点的树,有Q次询问,每次询问u到v的最短路径编号序列中,最长的先增后减或先减后增的连续子序列长度为多少。询问在线。
Hint
对于 20%数据, N,Q≤100
对于 40%数据, N,Q≤20000
对于 70%数据, N,Q≤50000
对于 100%数据, N,Q≤100000
Solution
这道题我刚看感觉极其复杂,而且还强制在线,实在不是人能做的。我的思路错误地囿于链剖,而且我想到的方法又极其繁琐,所以我还剩1hour的时候就毫不犹豫地选择了暴力。
事实上由于本题没有修改操作,完全可以不用链剖,改用树上RMQ。(十万火力嘲讽lyl)
我先略提一下我原本的想法。我们设一下从每个节点往上
2j
2
j
步的编号序列单调递增、单调递减、先增后减、先减后增的最长长度,并用
O(nlog2n)
O
(
n
l
o
g
2
n
)
的时间复杂度预处理出来。然后若询问点u,v之间的路径,我们可用倍增lca求出u,v的lca点z;然后根据倍增lca的经验,我们知道u到z的路径可被划分为至多
log2n
l
o
g
2
n
个区间,v到z同理,于是u到v即可划分为至多
2log2n
2
l
o
g
2
n
个区间。那么我们即可枚举一个区间i,表示那个单峰所在的区间;然后尽量往两边扩延。时间复杂度:
O(nlog2n+q(log2n)2)
O
(
n
l
o
g
2
n
+
q
(
l
o
g
2
n
)
2
)
。
但是还有一种更为简单、且更为快速的做法。
首先,我们可以花费
O(n)
O
(
n
)
的时间预处理出每个点v往上单调递增、单调递减、先增后减、先减后增的最长长度。设它们分别为len[v][0~3]。
然后,对于每个询问u,v,我们同样用倍增lca求出它们的lca点z,然后可分两种情况:1.经过z的答案;2.不经过z的答案。
对于情况1,我们也需要细分为多种情况。我只说一种:在u到z的路径上求能够直接连到z的、先增后减的最长路径,在v到z的路径上求一条能够直接到z的、单调递增的路径(因为是从v到z,如果是从z到v则是单调递减),拼接起来即成一种情况。那么对于要在u到z的路径上求的那一条路径,我们为了使答案最优,肯定让起点的深度尽量大,所以可以二分。但这是
O((log2n)2)
O
(
(
l
o
g
2
n
)
2
)
的。那么其实我们也可以学习倍增的思想,从u开始,从二进制高位依次枚举,设u跳完后的点为f,若以f为起点,len[f][2]依然够不到点z,那么就让u跳。这样即可在
O(log2n)
O
(
l
o
g
2
n
)
的时间内求出符合条件的点。
对于情况2,事实上也很好求。譬如我们要求u到z的路径上,最长的先减后增的路径(不必连到z),设在u到z的路径上求能够直接连到z的、先增后减的最长路径的起点为uir,则我们现在要求的路径的起点的深度<uir的深度便无意义。于是我们就需要在u到uir的路径上的点中求出max(len[][2])。这便转化为了一个树上RMQ问题。
时间复杂度:
O((n+q)log2n)
O
(
(
n
+
q
)
l
o
g
2
n
)
。
Code
#include <cstdio>
#include <algorithm>
using namespace std;
#define N 100001
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
int i,n,fat[N],tot,tov[N],next[N],last[N],anc[N][17],deep[N],len[N][4],ma[N][17][2],q,u,v,la;
inline void insert(int x,int y)
{
tov[++tot]=y;
next[tot]=last[x];
last[x]=tot;
}
void dfs(int x)
{
int i,j,y,t,f;
fo(i,0,1)ma[x][0][i]=len[x][i+2];
fo(i,1,16)
{
if(!(anc[x][i]=anc[f=anc[x][i-1]][i-1]))break;
fo(j,0,1)ma[x][i][j]=max(ma[x][i-1][j],ma[f][i-1][j]);
}
for(i=last[x];i;i=next[i])
{
y=tov[i];
if(y!=fat[x])
{
deep[y]=deep[x]+1;
fo(j,0,3)len[y][j]=1;
t=y<x?0:1;
len[y][t]=len[x][t]+1;
len[y][t+2]=len[x][t+2]+1;
fo(j,2,3)len[y][j]=max(len[y][j],len[y][3-j]);
dfs(y);
}
}
}
int lca(int x,int y)
{
int i,fx,fy;
if(deep[x]!=deep[y])
{
if(deep[x]<deep[y])swap(x,y);
fd(i,16,0)
{
if(deep[fx=anc[x][i]]>=deep[y])x=fx;
if(deep[x]==deep[y])break;
}
}
if(x!=y)
{
fd(i,16,0)
if((fx=anc[x][i])!=(fy=anc[y][i]))
x=fx,y=fy;
x=fat[x];
}
return x;
}
inline bool check(int u,int z,int t)
{
return deep[u]-len[u][t]<deep[z];
}
int get(int u,int z,int t)
{
if(check(u,z,t))return u;
int i,f;
fd(i,16,0)
if((f=anc[u][i])&&deep[f]>=deep[z]&&!check(f,z,t))
u=f;
return fat[u];
}
int access(int x,int y,int z,int t)
{
int i,f,ans=0,S=deep[x]-deep[z];
fd(i,16,0)
if(deep[f=anc[x][i]]>=deep[y])
ans=max(ans,ma[x][i][t]),x=f;
return min(max(ans,ma[x][0][t]),S);
}
int getans()
{
int z=lca(u,v),a,b,ui,ur,uir,uri,vi,vr,vir,vri;
uir=get(u,z,2); uri=get(u,z,3);
ur=get(uir,z,1);ui=get(uri,z,0);
vir=get(v,z,2); vri=get(v,z,3);
vr=get(vir,z,1);vi=get(vri,z,0);
a=max(max(deep[ui]+deep[vir],deep[ur]+deep[vri]),max(deep[uir]+deep[vi],deep[uri]+deep[vr]))-deep[z]*2+1;
b=max(max(access(u,uir,z,0),access(u,uri,z,1)),max(access(v,vir,z,0),access(v,vri,z,1)));
return max(a,b);
}
int main()
{
scanf("%d",&n);
fo(i,2,n)
{
scanf("%d",&fat[i]);
insert(anc[i][0]=fat[i],i);
}
deep[1]=1;
fo(i,0,3)len[1][i]=1;
dfs(1);
scanf("%d",&q);
fo(i,1,q)
{
scanf("%d%d",&u,&v);
u^=la;v^=la;
printf("%d\n",la=getans());
}
}