题意:
在一棵树中,每次选择一个区间[l,r]最多删除一个点,使得这个区间内所有点的lca的深度最大。
思路:
首先有一个点,就是一颗树中一堆点的LCA其实就是这堆点DFS序最小和最大的两个点的LCA,线段树维护区间max1,max2,min1,min2,然后倍增求LCA
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <string>
#include <queue>
#include <vector>
using namespace std;
int n,q;
vector<int> e[100005];
int dep[100005];
int tid[100005];
int atid[100005];
int cnt;
int fa[100005][20];
int inf=1e9+7;
void dfs(int u,int d)
{
cnt++;
tid[u]=cnt;atid[cnt]=u;dep[u]=d;
int len=e[u].size();
for(int i=0;i<len;i++)
{
int v=e[u][i];
dfs(v,d+1);
}
}
struct node
{
int l,r;
int max1,max2;
int min1,min2;
}tr[100005*5];
node up(node rt,node a,node b)
{
node ntr=rt;
if(a.max1>b.max1)
{
ntr.max1=a.max1;
ntr.max2=max(a.max2,b.max1);
}
else
{
ntr.max1=b.max1;
ntr.max2=max(a.max1,b.max2);
}
if(a.min1<b.min1)
{
ntr.min1=a.min1;
ntr.min2=min(a.min2,b.min1);
}
else
{
ntr.min1=b.min1;
ntr.min2=min(a.min1,b.min2);
}
return ntr;
}
void build(int i,int l,int r)
{
tr[i].l=l;tr[i].r=r;
if(l==r)
{
tr[i].max1=tid[l];tr[i].max2=0;
tr[i].min1=tid[l];tr[i].min2=inf;
} else
{
int mid=(l+r)/2;
build(i*2,l,mid);build(i*2+1,mid+1,r);
tr[i]=up(tr[i],tr[i*2],tr[i*2+1]);
}
}
node get(int i,int l,int r,int ql,int qr)
{
if(l==ql && r==qr)
return tr[i];
else
{
int mid=(l+r)/2;
if(qr<=mid)
return get(i*2,l,mid,ql,qr);
else if(ql>mid)
return get(i*2+1,mid+1,r,ql,qr);
else
{
node ans;
return up(ans,get(i*2,l,mid,ql,mid),get(i*2+1,mid+1,r,mid+1,qr));
}
}
}
int gf(int x,int y)
{
if(x==y)
return x;
else
{
if(dep[x]<dep[y])
swap(x,y);
for(int i=19;i>=0;i--)
{
if(dep[fa[x][i]]>=dep[y])
{
x=fa[x][i];
}
}
if(x==y)
return x;
for(int i=19;i>=0;i--)
{
if(fa[x][i]!=fa[y][i])
{
x=fa[x][i];y=fa[y][i];
}
}
return fa[x][0];
}
}
int main() {
while(~scanf("%d%d",&n,&q))
{
memset(fa,0,sizeof(fa));
for(int i=0;i<=100001;i++)
{
e[i].clear();
}
for(int i=2;i<=n;i++)
{
int tp;scanf("%d",&tp);
e[tp].push_back(i);
fa[i][0]=tp;
}
for(int i=1;i<20;i++)
{
for(int j=1;j<=n;j++)
{
fa[j][i]=fa[fa[j][i-1]][i-1];
}
}
cnt=0;
dfs(1,0);
build(1,1,n);
while(q--)
{
int l,r;
scanf("%d%d",&l,&r);
node ntp=get(1,1,n,l,r);
int x1=atid[ntp.max1],y1=atid[ntp.min2],ans1=atid[ntp.min1];int anss1=dep[gf(x1,y1)];
int x2=atid[ntp.max2],y2=atid[ntp.min1],ans2=atid[ntp.max1];int anss2=dep[gf(x2,y2)];
if(anss1>anss2)
printf("%d %d\n",ans1,anss1);
else
printf("%d %d\n",ans2,anss2);
}
}
return 0;
}