【树上算法】最近公共祖先(LCA)

最近公共祖先

upd on 2023.04.26:不少的用户反映代码有问题……怎么说呢……这算是一篇远古文章了吧(至少对于我说)所以我不太想改了,相信你们看了这篇文章之后能自己写出正确的代码吧~

前言

最近公共祖先是在树上的一个算法。这里就提到了树。树想必大家都再熟悉不过了,在计算机中,我们可以把树看成一种特殊的图。
这种图就叫有向无环联通图。顾名思义,没有环,有向图, n n n个节点 n − 1 n-1 n1条边。

一、什么是最近公共祖先

最近公共祖先(Lowest Common Ancestors,简称LCA)是一种典型的树上算法。
最近公共祖先是针对树上两个节点u和v的。这里就涉及到了祖先这个概念。a是b的祖先就是说a是b的父亲的父亲的父亲……(也包括父亲)。
现在我们就要求一个节点,既是u的祖先,也是v的祖先,离他们的距离要近一些。

二、LCA的计算方法

1.暴力枚举(不用说都知道)

时间复杂度: O ( n 2 ) O(n^2) O(n2)
直接超时!

2.简单递归

我们经过简单的思考,就能得出一个结论:
如果 u = v u=v u=v, 那么 L C A ( u , v ) = u ( 或 v ) LCA(u,v)=u(或v) LCA(u,v)=u(v)
把我们每次递归的时候就这样:(伪代码)

if u = v  return u
else
    u = u的父亲
    v = v的父亲

那我们就会发现一个问题:
u到根节点的距离为5,v到根节点的距离为4,那么他们不管怎么跳,始终跳不到一起。所以我们可以让他们跳到同一个深度
所以我们就先用一个dfs求出他们每个节点的深度,再让u和v跳到同一个深度,再执行刚才的代码。
求深度的代码:(伪代码)

def dfs: node,k
    dep[node]=k
    for son of node 1...n:
        dfs->son,k+1

根据刚才的思路,我们可以得到这些核心代码:

1.批量求节点的深度。

int dep[MAXN];
void dfs(int node, int k){
	dep[node] = k;
	for(int i = 0; i < G[node].size(); i++){
		dfs(G[node][i], k + 1);
	}
}

2.将两个节点调到同一个深度

while(dep[u] > dep[v]){
	u = fa[u];
}
while(dep[v] > dep[u]){
	v = fa[v];
}

3.求最近公共祖先

int lca(int u, int v){
	if(u == v)return u;
	return lca(fa[u], fa[v]);
}

可以说没有什么难度。(全部代码在文章末尾。)
时间复杂度: O ( n ) O(n) O(n),直接降了一次幂。

3.二进制优化

我们仔细看上面的代码,发现有些地方可以优化一下。
例如:“将两个节点调到同一个深度”这一块代码,我们发现:在调整的过程中,一步一步的向上“跳”有点慢,所以可以像跳棋一样来一个“大跳”。
对于一个节点,我们可以尝试让他们跳 2 n 2^n 2n步。
我们一步一步来。

step 1:跳到哪里去?

我们用 f ( i , j ) f(i,j) f(i,j)表示i节点向上跳 2 j 2^j 2j步后所到达的节点。
递推实现,递推式:
f [ i ] [ 0 ] = f a [ i ] f[i][0] = fa[i] f[i][0]=fa[i]
f [ i ] [ j ] = f [ f [ i ] [ j − 1 ] ] [ j − 1 ] f[i][j] = f[f[i][j - 1]][j - 1] f[i][j]=f[f[i][j1]][j1],即先向上跳2j-1步,再向上跳2j-1步。

step2:跳多少步?

二进制拆分。已知dep[u] = 64, dep[v] = 24,那我们可以这样:

1.尝试跳26步,即64步。dep[u] - 64 < dep[v],不可以。
2.尝试跳25步,即32步。dep[u] - 32 > dep[v],可以。dep[u] -= 32 dep[u] = 32
3.尝试跳24步,即16步。dep[u] - 16 < dep[v],不可以。
4.尝试跳23步,即8步。dep[u] - 8 == dep[v],算法结束。

就这样,本来要跳(64-24)步,现在只要跳4步。
另外,对于2n,我们用1 << n表示即可。

代码:
1.大跳前进:

for(int k = log(2, dep[u]); dep[u] != dep[v]; k--){
	if(dep[u] - (1 << k) >= dep[v]){
	    dep[u] -= (1 << k);
	    u = f[u][k];
	}
}

2.求出f数组:

for(int node = 1; node <= n; i++){
	f[node][0] = fa[node];
	for(int i = 1; i <= log(2, n); i++){
		f[node][i] = f[f[node][i - 1]][i - 1];
	}
}

有点思考含量在里面。

时间复杂度: O ( l o g   n ) O(log~n) O(log n)

三、完整代码

你所期待的完整代码终于来了。提供两个版本:
1. O ( n ) O(n) O(n)

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int INF = 0x3f3f3f;
const int MAXN = 1e4 + 5;
int fa[MAXN], dep[MAXN];
vector<int> G[MAXN];
void dfs(int node, int k){
	dep[node] = k;
	for(int i = 0; i < G[node].size(); i++){
		dfs(G[node][i], k + 1);
	}
}
int lca(int u, int v){
	if(u == v)return u;
	return lca(fa[u], fa[v]);
}
signed main(){
	memset(fa, -1, sizeof fa);
	int n, root;
	cin >> n;
	for(int i = 1; i < n; i++){
		int u, v;
		cin >> u >> v;
		fa[v] = u;
		G[u].push_back(v);
	}
	for(int i = 1; i <= n / 2 + 1; i++){
		if(fa[i] == -1){
			root = i;
			break;
		}
		if(fa[n - i] == -1){
			root = n - i;
			break;
		}
	}
	dfs(root, 0);
	int u, v;
	cin >> u >> v;
	while(dep[u] > dep[v]){
		u = fa[u];
	}
	while(dep[v] > dep[u]){
		v = fa[v];
	}
	cout << lca(u, v) << endl;
	return 0;
}

2. O ( l o g   n ) O(log~n) O(log n)

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int INF = 0x3f3f3f;
const int MAXN = 1e4 + 5;
const int MAXLOG = log(2, MAXN);
int fa[MAXN], dep[MAXN], f[MAXN][MAXLOG];
vector<int> G[MAXN];
void dfs(int node, int k){
	dep[node] = k;
	for(int i = 0; i < G[node].size(); i++){
		dfs(G[node][i], k + 1);
	}
}
void getf(){
	for(int node = 1; node <= n; i++){
		f[node][0] = fa[node];
		for(int i = 1; i <= log(2, n); i++){
			f[node][i] = f[f[node][i - 1]][i - 1];
		}
	}
}
int lca(int u, int v){
	if(u == v)return u;
	return lca(fa[u], fa[v]);
}
signed main(){
	memset(fa, -1, sizeof fa);
	getf();
	int n, root;
	cin >> n;
	for(int i = 1; i < n; i++){
		int u, v;
		cin >> u >> v;
		fa[v] = u;
		G[u].push_back(v);
	}
	for(int i = 1; i <= n / 2 + 1; i++){
		if(fa[i] == -1){
			root = i;
			break;
		}
		if(fa[n - i] == -1){
			root = n - i;
			break;
		}
	}
	dfs(root, 0);
	int u, v;
	cin >> u >> v;
	for(int k = log(2, dep[u]); dep[u] >= dep[v]; k--){
	    if(dep[u] - (1 << k) >= dep[v]){
	    	dep[u] -= (1 << k);
	    	u = f[u][k];
	    }
	}
	for(int k = log(2, dep[v]); dep[v] >= dep[u]; k--){
	    if(dep[v] - (1 << k) >= dep[u]){
	    	dep[v] -= (1 << k);
	    	v = f[v][k];
	    }
	}
	cout << lca(u, v) << endl;
	return 0;
}
  • 6
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值