一、定义
- LCA(Lowest Common Ancestors),即最近公共祖先,是指在有根树中,找出某两个结点u和v最近的公共祖先。
公共祖先是什么?对于x,y。如果z既是x的祖先也是y的祖先,那么我们就称z是x和y的公共祖先。
如上图,结点4,6的公共祖先有1、2,
但最近的公共祖先是2,即Lca(4,6) = 2
二、求法
思想:先让u,v中深度大的往上走,直到u,v深度相同,若此时u==v,则已找到。再让u,v一起往上走,直到走到同一个结点
时间复杂度:O(n)
暴力求解法时间复杂度其实也还可以接受,但是如果对于多组样例就不如tarjan了。
那么具体代码如下:
I.思想:注意到u,v走到最近公共祖先w之前,u,v所在结点不相同。而到达最近公共祖先w后,再往上走仍是u,v的公共祖先,即u,v走到同一个结点,这具有二分性质。于是可以预处理出一个 2 k 2^k 2k的表,fa[k][u]表示u往上走 2 k 2^k 2k步走到的结点,令根结点深度为0,则 2 k 2^k 2k>depth[u]时,令fa[k][u]=-1(不合法情况的处理)
不妨假设depth[u] < depth[v]
①将v往上走d = depth[v] - depth[u]步,此时u,v所在结点深度相同,该过程可用二进制优化。由于d是确定值,将d看成2的次方的和值,
d
=
2
k
1
+
2
k
2
+
.
.
.
+
2
k
m
d = 2^{k1} + 2^{k2} + ... + 2^{km}
d=2k1+2k2+...+2km,利用fa数组,如
v
=
f
a
[
k
1
]
[
v
]
v = fa[k1][v]
v=fa[k1][v],
v
=
f
a
[
k
2
]
[
v
]
v = fa[k2][v]
v=fa[k2][v]进行加速上升
②若此时
u
=
v
u = v
u=v,说明Lca(u,v)已找到
③利用fa数组加速u,v一起往上走到最近公共祖先w的过程。令
d
=
d
e
p
t
h
[
u
]
−
d
e
p
t
h
[
w
]
d = depth[u] - depth[w]
d=depth[u]−depth[w],虽然d是个未知值,但依然可以看成2的次方的和。从高位到低位枚举d的二进制位,设最低位为第0位,若枚举到第k位,有
f
a
[
k
]
[
u
]
!
=
f
a
[
k
]
[
v
]
fa[k][u] != fa[k][v]
fa[k][u]!=fa[k][v],则令
u
=
f
a
[
k
]
[
u
]
u = fa[k][u]
u=fa[k][u],
v
=
f
a
[
k
]
[
v
]
v = fa[k][v]
v=fa[k][v]。最后最近公共祖先
w
=
f
a
[
0
]
[
u
]
=
f
a
[
0
]
[
v
]
w = fa[0][u] = fa[0][v]
w=fa[0][u]=fa[0][v],即u和v的父亲.
II.那么我们接下来想如何预处理?
解法:
k=0时,
f
a
[
k
]
[
u
]
fa[k][u]
fa[k][u]为u在有根树中的父亲,令根结点
f
a
[
k
]
[
r
o
o
t
]
=
−
1
fa[k][root]=-1
fa[k][root]=−1。
k>0时,
f
a
[
k
]
[
u
]
=
f
a
[
k
−
1
]
[
f
a
[
k
−
1
]
[
u
]
]
fa[k][u]=fa[k-1][fa[k-1][u]]
fa[k][u]=fa[k−1][fa[k−1][u]]。树的高度最多为
n
n
n,k是
l
o
g
(
n
)
log(n)
log(n)级别。
III.复杂度:
预处理O(nlogn)
单次查询O(logn)
那么具体代码如下:
#include<cstdio>
#include<cstring>
#include<string>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
using namespace std;
int t;
int fa[10001];
int d[30001];
struct node{
int y,v,Next;
}e[50001];
int Fa[10010][25];
int n,m;
int len=0;
int linkk[10010];
bool vis[20001];
int root;
void insert(int x,int y,int v){
e[++len].Next=linkk[x];
linkk[x]=len;
e[len].v=v;
e[len].y=y;
}
void dfs(int now,int de){
if (vis[now]) return;
vis[now]=1;
if (d[now]==0&&now!=root) d[now]=de;else d[now]=min(d[now],de);
for(int i=linkk[now];i;i=e[i].Next){
int y=e[i].y;
if (y==Fa[now][0]) continue;
Fa[y][0]=now;
dfs(y,de+1);
}
}
void find_Fa(){
for (int j=1;(1<<j)<n;j++)
for (int i=1;i<=n;i++)
if (Fa[i][j-1]==-1) Fa[i][j]=-1;
else Fa[i][j]=Fa[Fa[i][j-1]][j-1];
}
int lca(int u,int v){
if (d[u]>d[v]) swap(u,v);
for (int dd=d[v]-d[u],i=0;dd;dd>>=1,i++)
if (dd&1) v=Fa[v][i];
if (u==v)return u;
for (int i=24;i>=0;i--)
if (Fa[u][i]!=Fa[v][i]) u=Fa[u][i],v=Fa[v][i];
return Fa[u][0];
}
int main(){
scanf("%d",&t);
while (t--){
int st,ed;
memset(vis,0,sizeof(vis));
memset(d,0,sizeof(d));
len=0;
memset(linkk,0,sizeof(linkk));
memset(Fa,0,sizeof(Fa));
memset(fa,0,sizeof(fa));
scanf("%d",&n);
for (int i=1,x,y;i<n;i++) scanf("%d %d",&x,&y),fa[y]=x,insert(x,y,1),insert(y,x,1);
scanf("%d %d",&st,&ed);
for (int i=1;i<=n;i++) if (!fa[i]){root=i;break;}
Fa[root][0]=-1;
dfs(root,0);
find_Fa();
printf("%d\n",lca(st,ed));
}
}
I.离线与在线的区别:
离线算法就是先把所有询问存起来,一次处理完,最后输出。
而在线算法就是即询问即计算,前面两个算法都是在线算法。
II.思想:
Tarjan算法基于这样一个事实,要找w=Lca(u,v),在dfs遍历完u到遍
历完v的过程中,遍历到v时,u到w路径上除w外结点的子树都遍历
过了,w的子树还未遍历完。如果对于结点u,访问完它的子树后就
把u在并查集中的父亲设为它在树中的父亲,那么访问到v时u在并
查集中的父亲就是Lca(u,v)。
那么具体代码如下:
#include<cstdio>
#include<cstring>
#include<string>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<map>
#include<queue>
using namespace std;
#define mp make_pair
typedef pair < int , int > pii;
int root;
int t;
int fa[40010];
vector < pii > e[80020];
vector < pii > id[80020];
bool vis[40010];
int n,m;
int len=0;
int ans[40010];
int a[40010];
int d[40010];
int getfa(int k){
return k==fa[k]?k:fa[k]=getfa(fa[k]);
}
void tarjan(int u){
vis[u]=1;
for (int i=0;i<e[u].size();i++){
int y=e[u][i].first;
if (vis[y]) continue;
tarjan(y);
fa[y]=u;
}
for (int i=0;i<id[u].size();i++)
if (vis[id[u][i].second])
ans[id[u][i].first]=d[u]+d[id[u][i].second]-2*d[getfa(id[u][i].second)];
}
void dfs(int u,int de){
if (vis[u]) return;
vis[u]=1;
d[u]=de;
for (int i=0;i<e[u].size();i++){
int y=e[u][i].first;
dfs(y,de+e[u][i].second);
}
}
int main(){
scanf("%d",&t);
while (t--){
memset(vis,0,sizeof(vis));
memset(ans,0,sizeof(ans));
memset(d,0,sizeof(d));
memset(a,0,sizeof(a));
memset(vis,0,sizeof(vis));
scanf("%d %d",&n,&m);
for (int i=1;i<=n;i++) e[i].clear();
for (int i=1;i<=m;i++) id[i].clear();
for (int i=1,x,y,z;i<n;i++) scanf("%d %d %d",&x,&y,&z),a[y]=x,e[x].push_back(mp(y,z)),e[y].push_back(mp(x,z));
for (int i=1,x,y;i<=m;i++) scanf("%d %d",&x,&y),id[x].push_back(mp(i,y)),id[y].push_back(mp(i,x));
for (int i=1;i<=n;i++) if (!a[i]){root=i;break;}
dfs(root,0);
memset(vis,0,sizeof(vis));
for (int i=1;i<=n;i++) fa[i]=i;
tarjan(root);
for (int i=1;i<=m;i++)
printf("%d\n",ans[i]);
}
return 0;
}
具体例题请看我的博客题解