最近公共祖先
upd on 2023.04.26:不少的用户反映代码有问题……怎么说呢……这算是一篇远古文章了吧(至少对于我说)所以我不太想改了,相信你们看了这篇文章之后能自己写出正确的代码吧~
前言
最近公共祖先是在树上的一个算法。这里就提到了树。树想必大家都再熟悉不过了,在计算机中,我们可以把树看成一种特殊的图。
这种图就叫有向无环联通图。顾名思义,没有环,有向图,
n
n
n个节点
n
−
1
n-1
n−1条边。
一、什么是最近公共祖先
最近公共祖先(Lowest Common Ancestors,简称LCA)是一种典型的树上算法。
最近公共祖先是针对树上两个节点u和v的。这里就涉及到了祖先这个概念。a是b的祖先就是说a是b的父亲的父亲的父亲……(也包括父亲)。
现在我们就要求一个节点,既是u的祖先,也是v的祖先,离他们的距离要近一些。
二、LCA的计算方法
1.暴力枚举(不用说都知道)
时间复杂度:
O
(
n
2
)
O(n^2)
O(n2)
直接超时!
2.简单递归
我们经过简单的思考,就能得出一个结论:
如果
u
=
v
u=v
u=v, 那么
L
C
A
(
u
,
v
)
=
u
(
或
v
)
LCA(u,v)=u(或v)
LCA(u,v)=u(或v)
把我们每次递归的时候就这样:(伪代码)
if u = v return u
else
u = u的父亲
v = v的父亲
那我们就会发现一个问题:
u到根节点的距离为5,v到根节点的距离为4,那么他们不管怎么跳,始终跳不到一起。所以我们可以让他们跳到同一个深度。
所以我们就先用一个dfs求出他们每个节点的深度,再让u和v跳到同一个深度,再执行刚才的代码。
求深度的代码:(伪代码)
def dfs: node,k
dep[node]=k
for son of node 1...n:
dfs->son,k+1
根据刚才的思路,我们可以得到这些核心代码:
1.批量求节点的深度。
int dep[MAXN];
void dfs(int node, int k){
dep[node] = k;
for(int i = 0; i < G[node].size(); i++){
dfs(G[node][i], k + 1);
}
}
2.将两个节点调到同一个深度
while(dep[u] > dep[v]){
u = fa[u];
}
while(dep[v] > dep[u]){
v = fa[v];
}
3.求最近公共祖先
int lca(int u, int v){
if(u == v)return u;
return lca(fa[u], fa[v]);
}
可以说没有什么难度。(全部代码在文章末尾。)
时间复杂度:
O
(
n
)
O(n)
O(n),直接降了一次幂。
3.二进制优化
我们仔细看上面的代码,发现有些地方可以优化一下。
例如:“将两个节点调到同一个深度”这一块代码,我们发现:在调整的过程中,一步一步的向上“跳”有点慢,所以可以像跳棋一样来一个“大跳”。
对于一个节点,我们可以尝试让他们跳
2
n
2^n
2n步。
我们一步一步来。
step 1:跳到哪里去?
我们用
f
(
i
,
j
)
f(i,j)
f(i,j)表示i节点向上跳
2
j
2^j
2j步后所到达的节点。
递推实现,递推式:
f
[
i
]
[
0
]
=
f
a
[
i
]
f[i][0] = fa[i]
f[i][0]=fa[i]
f
[
i
]
[
j
]
=
f
[
f
[
i
]
[
j
−
1
]
]
[
j
−
1
]
f[i][j] = f[f[i][j - 1]][j - 1]
f[i][j]=f[f[i][j−1]][j−1],即先向上跳2j-1步,再向上跳2j-1步。
step2:跳多少步?
二进制拆分。已知dep[u] = 64, dep[v] = 24
,那我们可以这样:
1.尝试跳26步,即64步。dep[u] - 64 < dep[v]
,不可以。
2.尝试跳25步,即32步。dep[u] - 32 > dep[v]
,可以。dep[u] -= 32
dep[u] = 32
3.尝试跳24步,即16步。dep[u] - 16 < dep[v]
,不可以。
4.尝试跳23步,即8步。dep[u] - 8 == dep[v]
,算法结束。
就这样,本来要跳(64-24)步,现在只要跳4步。
另外,对于2n,我们用1 << n
表示即可。
代码:
1.大跳前进:
for(int k = log(2, dep[u]); dep[u] != dep[v]; k--){
if(dep[u] - (1 << k) >= dep[v]){
dep[u] -= (1 << k);
u = f[u][k];
}
}
2.求出f数组:
for(int node = 1; node <= n; i++){
f[node][0] = fa[node];
for(int i = 1; i <= log(2, n); i++){
f[node][i] = f[f[node][i - 1]][i - 1];
}
}
有点思考含量在里面。
时间复杂度: O ( l o g n ) O(log~n) O(log n)
三、完整代码
你所期待的完整代码终于来了。提供两个版本:
1.
O
(
n
)
O(n)
O(n)版
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int INF = 0x3f3f3f;
const int MAXN = 1e4 + 5;
int fa[MAXN], dep[MAXN];
vector<int> G[MAXN];
void dfs(int node, int k){
dep[node] = k;
for(int i = 0; i < G[node].size(); i++){
dfs(G[node][i], k + 1);
}
}
int lca(int u, int v){
if(u == v)return u;
return lca(fa[u], fa[v]);
}
signed main(){
memset(fa, -1, sizeof fa);
int n, root;
cin >> n;
for(int i = 1; i < n; i++){
int u, v;
cin >> u >> v;
fa[v] = u;
G[u].push_back(v);
}
for(int i = 1; i <= n / 2 + 1; i++){
if(fa[i] == -1){
root = i;
break;
}
if(fa[n - i] == -1){
root = n - i;
break;
}
}
dfs(root, 0);
int u, v;
cin >> u >> v;
while(dep[u] > dep[v]){
u = fa[u];
}
while(dep[v] > dep[u]){
v = fa[v];
}
cout << lca(u, v) << endl;
return 0;
}
2. O ( l o g n ) O(log~n) O(log n)版
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int INF = 0x3f3f3f;
const int MAXN = 1e4 + 5;
const int MAXLOG = log(2, MAXN);
int fa[MAXN], dep[MAXN], f[MAXN][MAXLOG];
vector<int> G[MAXN];
void dfs(int node, int k){
dep[node] = k;
for(int i = 0; i < G[node].size(); i++){
dfs(G[node][i], k + 1);
}
}
void getf(){
for(int node = 1; node <= n; i++){
f[node][0] = fa[node];
for(int i = 1; i <= log(2, n); i++){
f[node][i] = f[f[node][i - 1]][i - 1];
}
}
}
int lca(int u, int v){
if(u == v)return u;
return lca(fa[u], fa[v]);
}
signed main(){
memset(fa, -1, sizeof fa);
getf();
int n, root;
cin >> n;
for(int i = 1; i < n; i++){
int u, v;
cin >> u >> v;
fa[v] = u;
G[u].push_back(v);
}
for(int i = 1; i <= n / 2 + 1; i++){
if(fa[i] == -1){
root = i;
break;
}
if(fa[n - i] == -1){
root = n - i;
break;
}
}
dfs(root, 0);
int u, v;
cin >> u >> v;
for(int k = log(2, dep[u]); dep[u] >= dep[v]; k--){
if(dep[u] - (1 << k) >= dep[v]){
dep[u] -= (1 << k);
u = f[u][k];
}
}
for(int k = log(2, dep[v]); dep[v] >= dep[u]; k--){
if(dep[v] - (1 << k) >= dep[u]){
dep[v] -= (1 << k);
v = f[v][k];
}
}
cout << lca(u, v) << endl;
return 0;
}