【模板】最近公共祖先(LCA)
前言
相信大家对LCA一定不陌生,LCA的全称为最近公共祖先(lowest common ancestor)简称为LCA。
LCA可以算是一个经典的树上算法了。
模板题目
如果大家想做LCA的模板题,大家可以上洛谷的P3379【模板】最近公共祖先(LCA)来做题。
不过,我得提醒一下大家,在我做的时候,我将 M a x n Maxn Maxn以及 M a x m Maxm Maxm开到了 500000 500000 500000,不过直到我把 M a x n Maxn Maxn和 M a x m Maxm Maxm开到 1000000 1000000 1000000的时候,他才没有RE。
这怕不是又是洛谷的feature \bold{\Large\xcancel\text{这怕不是又是洛谷的feature}} 这怕不是又是洛谷的feature
LCA
LCA的定义
给定一棵有根树,若结点 z z z既是结点 x x x的祖先,又是结点 y y y 的祖先,则称 z z z 为 x , y x,y x,y 的公共祖先。可以发现, x , y x,y x,y 的所有公共祖先是在从根结点开始的某条链上,其中最深的那个公共祖先(也就是离 x , y x,y x,y 最近的)被称为最近公共祖先,这就叫LCA。
不明白什么叫祖先?
看下面这张图:
在这张图中,我们称点 A A A为点 B , C , D , E , F B,C,D,E,F B,C,D,E,F的祖先, B B B为 E , F E,F E,F的祖先。
LCA的初步实现
最普通的暴力
在知道了LCA是什么之后,大家一定会想到这样的实现:
求 LCA 最容易想到的方案是:
- 先从 x x x往上走到根,沿途会经过 x x x所有的祖先,把它们用一个数组标记。
- 再从 y y y往上走到根,沿途会经过 y y y所有的祖先,遇到的第一个被标记的点就是 x , y x,y x,y 的最近公共祖先。
代码如下,时间复杂度为 O ( n ) \mathcal{O}(n) O(n)。
int fa[MAX_N], vis[MAX_N]; // fa 数组保存每个结点的父节点,vis 数组用来标记
int LCA(int x, int y) {
memset(vis, 0, sizeof vis);
while (x != 0) {
vis[x] = 1;
x = fa[x];
}
while (vis[y] == 0) {
y = fa[y];
}
return y;
}
优化++
但是,似乎每次都标记也太浪费了。
我们可以先让两个点走到同一深度,然后一起往上走。
我们需要用到 d f s dfs dfs来实现这个走到同一深度的方法。
int d[MAX_N], fa[MAX_N]; // d 数组保存每个结点的深度
void dfs(int u) {
d[u] = d[fa[u]] + 1;
for (int i = p[u]; i != -1; i = e[i].next) {
int v = e[i].v;
if (v != fa[u]) {
fa[v] = u;
dfs(v);
}
}
}
int lca(int x, int y) {
if (d[x] < d[y]) {
swap(x, y); // 让 x 为深度更深的那个点
}
while (d[x] > d[y]) {
x = fa[x]; // 让 x 和 y 处于同一深度
}
while (x != y) {
x=fa[x];
y=fa[y];
}
return x;
}
但这种做法的时间复杂度依然为 O ( n ) \mathcal{O}(n) O(n)。
这个算法要优化的更好的瓶颈在于通过 f a fa fa数组往上走,每次走一步实在太慢了。那么有没有方法可以一次性走一大步呢?
LCA的进阶实现
进阶的思想
答案是采用二进制的方法往上跳。
就比如这段代码:
while (d[x] > d[y]) {
x = fa[x]; // 让 x 和 y 处于同一深度
}
就可以改为
int K = 0;
while ((1 << (K + 1)) <= d[x]) {
K++;
}
for (int i = K; i >= 0; i--) {
//如果 x 的 2^i 祖先深度大于等于 y 的深度,x 就往上跳到 2^i 祖先
}
其中 K K K为最大的整数满足 2 K ≤ d [ x ] 2^K \le d[x] 2K≤d[x]。
我们让 x x x每次尝试跳 2 i 2^i 2i步, i i i从 K K K开始从大到小枚举。如果跳跃后深度依然不小于 y y y,就选择跳跃。
换种角度思考,设 t = d [ x ] − d [ y ] t=d[x]-d[y] t=d[x]−d[y],那么 t t t的二进制表示中 1 1 1的位置就是 x x x要跳的那步。相当于用若干个不同的 2 2 2的幂次来凑出这个 t t t,我们肯定会选择从大到小凑,并且最终方案肯定是唯一的。
同理,当 x , y x,y x,y 到达同一深度后,两个点继续同时往上跳的步骤也可以用这种二进制尝试跳跃的方法。 如果能在 O ( 1 ) \mathcal{O}(1) O(1) 时间内得到个结点的 2 2 2 的幂次辈祖先,那么这种方法计算 L C A ( x , y ) LCA(x,y) LCA(x,y)的时间复杂度就为 O ( log n ) \mathcal{O}(\log n) O(logn)。 现在的问题变为如何预处理每个结点的 2 2 2的幂次辈祖先?
用 d p dp dp+ d f s dfs dfs来搞 2 i 2^i 2i的预处理
我们解决的方法是动态规划。
我们定义f[u][j]
表示u
节点的
2
j
2^j
2j辈祖先(如果没有则为
0
0
0)。那么f[u][0]
就是
u
u
u的父节点。我们在
D
F
S
DFS
DFS求深度的时候同时维护一下下即可。
void dfs(int u) {
d[u]=d[f[u][0]]+1;
for (int i=p[u];i!=-1;i=e[i].next) {
int v=e[i].v;
if (v == f[u][0]) {
continue;
}
f[v][0]=u;
dfs(v);
}
}
然后通过递推计算所有结点的 2 2 2的幂次辈祖先
for(int j=1;(1<<j)<=n;j++){
for(int i=1;i<=n;i++){
f[i][j]=f[f[i][j-1]][j-1];
}
}
转移过程也很好理解, i i i 的 2 j 2^j 2j辈祖先等于 i i i的 2 j − 1 2^{j-1} 2j−1辈祖先的 2 j − 1 2^{j-1} 2j−1辈祖先。
这步预处理的时间是复杂度为 O ( n log n ) \mathcal{O}(n\log n) O(nlogn)
正版LCA
我们现在就可以用上次说的LCA来完成了。
int lca(int x,int y){
if(d[x]<d[y]){
swap(x,y);
}
int K=0;
while((1<<(K+1))<=d[x]){
K++;
}
for(int j=K;j>=0;j--){
if(d[f[x][j]]>=d[y]){
x=f[x][j];
}
}
if(x==y){
return x;
}
for(int j=K;j>=0;j--){
if(f[x][j]!=f[y][j]){
x=f[x][j];
y=f[y][j];
}
}
return f[x][0];
}
我们再回顾一下,首先通过交换确保 x x x的深度更深,然后找到最大的 K K K 满足 2 K ≤ d [ x ] 2^K\le d[x] 2K≤d[x],作为二进制尝试跳跃的上界。接着通过次若干次尝试往上跳,让 x x x和 y y y的深度相同。这时候如果 x x x和 y y y已经是同一个结点了,就直接返回结果。否则,让两个点继续尝试同时往上跳一样的步数,注意只有在两个结点跳 2 j 2^j 2j次步后不相等时才会往上跳 换句话说,循环结束后, x x x和 y y y分别是它们 L C A LCA LCA的儿子 因此它们的父节点就是 L C A LCA LCA。
完整的正版LCA源代码
不知道大家发现了没有,这个方法就叫做倍增法。如果你想了解更多关于倍增法的知识,请观看这位dalao的博文
#include<bits/stdc++.h>
using namespace std;
const int MAX_N=1000100;
const int MAX_M=1000100;
struct Edge {
int v,next;
}e[MAX_M];
int p[MAX_N],eid;
void init() {
memset(p,-1,sizeof(p));
eid=0;
}
void insert(int u,int v) {
e[eid].v=v;
e[eid].next=p[u];
p[u]=eid++;
}
int f[MAX_N][20],d[MAX_N];
void dfs(int u) {
d[u]=d[f[u][0]]+1;
for (int i=p[u];i!=-1;i=e[i].next) {
int v=e[i].v;
if (v == f[u][0]) {
continue;
}
f[v][0]=u;
dfs(v);
}
}
int lca(int x,int y){
if(d[x]<d[y]){
swap(x,y);
}
int K=0;
while((1<<(K+1))<=d[x]){
K++;
}
for(int j=K;j>=0;j--){
if(d[f[x][j]]>=d[y]){
x=f[x][j];
}
}
if(x==y){
return x;
}
for(int j=K;j>=0;j--){
if(f[x][j]!=f[y][j]){
x=f[x][j];
y=f[y][j];
}
}
return f[x][0];
}
int main() {
int n,m,q;
scanf("%d%d%d",&n,&m,&q);
init();
for (int i=1;i<n;i++) {
int u,v;
scanf("%d%d",&u,&v);
insert(u,v);
insert(v,u);
}
dfs(q);
for(int j=1;(1<<j)<=n;j++){
for(int i=1;i<=n;i++){
f[i][j]=f[f[i][j-1]][j-1];
}
}
while (m--) {
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",lca(x,y));
}
return 0;
}