数据结构中的树,在计算机科学中是非常重要的,例如我们来看看下面这棵树:
在图中我们对每个节点都有编号了。 8号节点是这棵树的根。我们定义,一个子节点向它的根节点的路径上,任意一个节点都称为它的祖先。例如, 4号节点是16号节点的祖先。而10号节点同样也是16号的祖先。事实上,16号的祖先有8,4,10,16共四个。另外8, 4, 6,7都是7号节点的祖先,所以7号和16号的公共祖先是4和8号,而在这两个里面,4号是距离7和16最近的一个,所以我们称7号和16号的最近公共祖先是4号。
再例如,2和3的最近公共祖先是10,再例如6和13的是8。
现在你需要编写一个程序,在一棵树中找出指定两个节点的最近公共祖先
Input
第一行输入T表示有T组数据。每组第一行是N表示这棵树有多少个节点,其中 2<=N<=10,000。 节点用正整数1, 2,…, N表示。 接下来的 N -1 行表示这棵树的边,每行两个数,都是节点编号,前一个是后一个的父节点。最后一行是要查询的两个节点,计算出这两个节点的最近公共祖先
Output
对于每组测试输出一行,输出它们的最近公共祖先的编号。
Sample Input
2
16
1 14
8 5
10 16
5 9
4 6
8 4
4 10
1 13
6 15
10 11
6 7
10 2
16 3
8 1
16 12
16 7
5
2 3
3 4
3 1
1 5
3 5
Sample Output
4
3
解题思路:这个题就直接套LCA的板子就可以啦,我们先来讲一下LCA叭,
求LCA有三种方法(我知道的,可能有其他的,我太菜了,不知道),分别是并查集,RMQ,和倍增,这里我们来讲一下倍增,我们先介绍一下所需的变量。
lg[x]:表示2lg[x]是大于x的最小的数。
depth[x]:表示结点x在树中的深度。
f[u][i]:表示结点u的第2i位祖先。
倍增法的过程是:
1:先预处理出每个点的2的指数倍的祖先结点。
2:再寻找的时候,先让两个结点的深度相同,处理方法为找到较深的结点的与另一个结点相同深度的祖父结点。
3:此时如果是两个结点相同,那么答案就是本身,否则执行步骤4。
4:寻找两个结点中不相等的祖先结点,然后去寻找祖先结点的祖先结点,一直到找不到为止。
注意:在执行步骤4时,需要从最大的2的指数倍开始找,原因是:
如果从小开始找,比如LCA是结点的第17辈的祖先,
那么我们找就是 1+2+4+8=15,这个时候还差2才可以得到第17位,这里就要后悔,要返回前面的操作,而我们从大往小找的话:16+1=17,就只需要两步就得到答案啦。
代码:
#pragma GCC optimize(2)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <string>
#include <vector>
#include <set>
#include <map>
#include <stack>
#include <bitset>
#include <queue>
//#include <random>
#include <time.h>
using namespace std;
#define int long long
#define ull unsigned long long
#define ls root<<1
#define rs root<<1|1
const int maxn = 1e4 + 7;
//std::mt19937 rnd(time(NULL));
struct edge
{
int v,next;
}e[maxn];
int cnt,head[maxn],f[maxn][20],depth[maxn],lg[maxn];
void add(int a,int b)
{
e[++cnt]=edge{b,head[a]};
head[a]=cnt;
}
void dfs(int u,int fa)
{
f[u][0]=fa;
depth[u]=depth[fa]+1;
for(int i=1;(1<<i)<=depth[u];i++){
f[u][i]=f[f[u][i-1]][i-1];
}
for(int i=head[u];i;i=e[i].next){
int v=e[i].v;
dfs(v,u);
}
}
int getlca(int a,int b)
{
if(depth[a]<depth[b])swap(a,b);
while(depth[a]>depth[b]){
a=f[a][lg[depth[a]-depth[b]]-1];
}
if(a==b)return a;
for(int i=lg[depth[a]]-1;i>=0;i--){
if(f[a][i]!=f[b][i]){
a=f[a][i],b=f[b][i];
}
}
return f[a][0];
}
signed main()
{
int t;
scanf("%lld",&t);
for(int i=1;i<maxn;i++){
lg[i]=lg[i-1]+((1<<lg[i-1])==i);
}
while(t--){
int n;
scanf("%lld",&n);
memset(head,0,sizeof head);
cnt=0;
int root=n;
for(int i=1;i<n;i++){
int a,b;
root^=i;
scanf("%lld%lld",&a,&b);
root^=b;
add(a,b);
}
dfs(root,0);
int u,v;
scanf("%lld%lld",&u,&v);
printf("%lld\n",getlca(u,v));
}
}