题目链接:http://codeforces.com/contest/832/problem/D
题目大意:n个点,n-1条边,然后q次查询,每次查询给你三个数a,b,c,从这三个数中选择两个数
作为起点,剩下一个数作为终点,问你两个起点到终点的路线中能够重合的点最多有多少个。
题目思路:乍一看感觉像是树形DP 或者 树链刨分之类的随便搞搞,但是仔细想想,求经过的点数和
求树上两点之间的距离不是一样的吗,不过最后加一的区别而已。树上两点距离的话最近公共祖先随便
搞搞就可以了,但是有q次询问,因此需要在线最近公共祖先,呢就倍增喽(新学的。。。。)
什么是最近公共祖先我就不多说了,这里就简要说一下倍增吧,为什么要倍增了,很简单,之前的离线比较
暴力,倍增算法的区别是保存当前点到所有距离的父亲有谁,也就是fa数组,但是为什么是倍增呢?
定义fa[ i ] [ j ]:表示点 i 向上走2^j的距离的人父亲是谁,既然是幂次的上升,呢就是倍增了(我是这样理解的)
呢递推关系就很好理解了: fa[ i ][ j ]=fa[ fa[ i ][ j-1 ] ][ j-1 ];
剩下的就是怎么找最近公共祖先了,倍增方法其实比Tarjan算法好理解,我就稍微说一下核心的地方吧
首先就是定义一个根,dfs一遍求出每一点的深度和父亲,代码如下:
void dfs(int u,int p)
{
int i,j;
dep[u]=dep[p]+1;
fa[u][0]=p;
for(i=1;i<20;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(i=0;i<e[u].size();i++)
{
int v=e[u][i];
if(v==p)
continue;
dfs(v,u);
}
}
然后就是LCA了,其实很好弄明白这个过程,你想一想,两个点的公共祖先一定是他们的最近‘父亲’了,
呢就向上跳就可以了,怎么跳呢,别忘了之前定义的fa数组,呢肯定按2的幂次跳喽,首先找两者中深度最
深的,向上一直跳到和另一点同一层,这很好理解吧,不跳到同一层怎么一起找父亲,剩下的就是向上找共同的
父亲了,找到后返回即可,代码如下:
int lca(int x,int y)
{
if(dep[x]<dep[y])
swap(x,y);
int i;
for(i=19;i>=0;i--)
if(dep[x]-(1<<i)>=dep[y])
x=fa[x][i];
if(x==y)
return x;
for(i=19;i>=0;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
倍增算法就讲到这里了,我讲的不是很清楚,只是把自己理解到的说了出来。。。。
本题代码:
#include<map>
#include<stack>
#include<queue>
#include<vector>
#include<math.h>
#include<stdio.h>
#include<iostream>
#include<string.h>
#include<stdlib.h>
#include<algorithm>
using namespace std;
typedef long long ll;
#define inf 1000000000
#define mod 1000000007
#define maxn 210000
#define lowbit(x) (x&-x)
#define eps 1e-10
vector<int>e[maxn];
int dep[maxn],fa[maxn][50];
void dfs(int u,int p)
{
int i,j;
dep[u]=dep[p]+1;
fa[u][0]=p;
for(i=1;i<20;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(i=0;i<e[u].size();i++)
{
int v=e[u][i];
if(v==p)
continue;
dfs(v,u);
}
}
int lca(int x,int y)
{
if(dep[x]<dep[y])
swap(x,y);
int i;
for(i=19;i>=0;i--)
if(dep[x]-(1<<i)>=dep[y])
x=fa[x][i];
if(x==y)
return x;
for(i=19;i>=0;i--)
if(fa[x][i]!=fa[y][i])
x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
int route(int x,int y)
{
int tmp=lca(x,y);
return dep[x]+dep[y]-2*dep[tmp];
}
int findanswer(int a,int b,int c)
{
return (route(a,c)+route(b,c)-route(a,b))/2+1;
}
void work(int a,int b,int c)
{
int ans=0,tmp[10];
tmp[1]=findanswer(a,b,c);
tmp[2]=findanswer(a,c,b);
tmp[3]=findanswer(b,a,c);
tmp[4]=findanswer(b,c,a);
tmp[5]=findanswer(c,a,b);
tmp[6]=findanswer(c,b,a);
for(int i=1;i<=6;i++)
ans=max(ans,tmp[i]);
printf("%d\n",ans);
return;
}
int main(void)
{
int q,i,j,x,y,z,n;
scanf("%d%d",&n,&q);
for(i=2;i<=n;i++)
{
scanf("%d",&x);
e[x].push_back(i);
e[i].push_back(x);
}
dfs(1,0);
while(q--)
{
scanf("%d%d%d",&x,&y,&z);
work(x,y,z);
}
return 0;
}