最近公共祖先(LCA)
顾名思义,就是两个结点的最近公共父结点;
这里运用了倍增的思路,当你想要一个个的往上递增时,复杂度很高,可以把要递增的数量换成二进制表示(因为任意一个数都可以由二进制数相加得到);
比如11,可以由 2^3=8,2 ^1=2, 2 ^0=1,相加而成;所以11这个数可以先加8,在加2,在加1得到;
这里有两个预处理,一个是得到每个结点所在的层数,一个是得到每个结点往上走2的 j 次方次得到的结点;
这两个操作可以放在一个dfs函数里面,应该非常好理解:
fa[sn][i]=fa[fa[sn][i-1]][i-1];可以手动模拟,大致意思为2^n=2 ^n-1 *+2 ^ n-1;
void dfs(int sn,int ft){
dep[sn]=dep[ft]+1,fa[sn][0]=ft;
for(int i=1;i<=lg[dep[sn]];i++) fa[sn][i]=fa[fa[sn][i-1]][i-1];
for(int i=head[sn];~i;i=edge[i].nex){
if(edge[i].to!=ft) dfs(edge[i].to,sn);
}
}
这里还学到一种预处理log_2(i)+1的方法:
for(int i=1;i<=n;i++) lg[i]=lg[i-1]+(1<<lg[i-1]==i);//预处理(log_2(i))+1的值
这个东西就是求一个数要用二进制数组成的最大次数:
比如11,它的lg[11]=4,我们在调用lg[11]的时候还要减去1,也就是说要组成11,最多只要3次;
然后核心的 lca 代码主要的思路就是:先使 x 和 y 在同一层,然后在一起往上走,找到最顶部的两个结点 x 和 y ,结点 x 和 y 的父节点相等,那么 x 的父节点就是答案;
int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
while(dep[x]>dep[y]) x=fa[x][lg[dep[x]-dep[y]]-1];
if(x==y) return x;
for(int k=lg[dep[x]]-1;k>=0;k--){
if(fa[x][k]!=fa[y][k]){
x=fa[x][k],y=fa[y][k];
}
}
return fa[x][0];
}
全部题目来做这道题:
【模板】最近公共祖先(LCA)
全部代码:
#include<bits/stdc++.h>
#define LL long long
#define pa pair<int,int>
#define ls k<<1
#define rs k<<1|1
#define inf 0x3f3f3f3f
using namespace std;
const int N=500100;
const int M=1000100;
const LL mod=100000000;
int n,m,s,head[N],cnt,lg[N],dep[N],fa[N][40];
struct Node{
int to,nex;
}edge[M];
void add(int p,int q){
edge[cnt].to=q;
edge[cnt].nex=head[p];
head[p]=cnt++;
}
void dfs(int sn,int ft){
dep[sn]=dep[ft]+1,fa[sn][0]=ft;
for(int i=1;i<=lg[dep[sn]];i++) fa[sn][i]=fa[fa[sn][i-1]][i-1];
for(int i=head[sn];~i;i=edge[i].nex){
if(edge[i].to!=ft) dfs(edge[i].to,sn);
}
}
int lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
while(dep[x]>dep[y]) x=fa[x][lg[dep[x]-dep[y]]-1];
if(x==y) return x;
for(int k=lg[dep[x]]-1;k>=0;k--){
if(fa[x][k]!=fa[y][k]){
x=fa[x][k],y=fa[y][k];
}
}
return fa[x][0];
}
int main(){
// ios::sync_with_stdio(false);
memset(head,-1,sizeof(head));
cin>>n>>m>>s;
for(int i=1;i<n;i++){
int p,q;
scanf("%d%d",&p,&q);
add(p,q),add(q,p);
}
for(int i=1;i<=n;i++) lg[i]=lg[i-1]+(1<<lg[i-1]==i);//预处理(log_2(i))+1的值
dfs(s,0);
while(m--){
int a,b;
scanf("%d%d",&a,&b);
printf("%d\n",lca(a,b));
}
return 0;
}