LCA(最近公共祖先)(倍增 + tarjan)
LCA:
LCA(Least Common Ancestors),即最近公共祖先。
祖先结点:从一个结点出发,一直往上走,直至根节点,这条路上所有的结点都是该结点的祖先结点。所以自己也是自己的祖先。
举个例子:对于图中3、4结点LCA(3, 4) = 2; 图中5、6结点LCA(5, 6) = 5; 图中4,5结点LCA(4, 5) = 1;
倍增:
倍增求LCA又称最大跳跃求LCA。
先介绍下倍增思想:
- 首先给定一棵树,根据相关边键图(这里建无向图或单向图都可行,本文建无向图)
- 其次,深度访问建好的图,从根节点出发,给每个结点的深度算出来,初始化关键数组 f a [ i ] [ j ] fa[i][j] fa[i][j],表示从i结点往根节点出发的第 2 j 2^j 2j个祖先结点, f a [ i ] [ 0 ] fa[i][0] fa[i][0]即i结点的父节点, 2 0 = = 1 2^0 == 1 20==1这个数组可以说是倍增思想的关键(RMQ也是这个原理)
- 查询LCA,给定x,y结点查找LCA(x, y)。查询函数的实现细节:先将两个节点通过 f a [ i ] [ j ] fa[i][j] fa[i][j]进行跳跃,知道两个节点跳跃到相同高度,这个跳跃是最大跳跃,每次通过 f a [ i ] [ j ] fa[i][j] fa[i][j]实现 2 j 2^j 2j次跳跃,相比一次跳跃一次快的多。当xy两节点跳跃到同一深度时,若此时xy结点重合了,那么说明xy此时便是要找的最近公共祖先(例如上图求LCA(6, 5),6往上跳跃到5, 此时5和6重合,LCA(5, 6) == 5)。若还没,例如上图LCA(3, 4),此都在同一深度,那么就需要两个同时往上跳,直到找到2,此时LCA(3, 4) == 2。
具体函数实现:
一、建图
for(int i = 1; i < n; i++) { //这里注意只有n-1条边,不要写错了,树的特点:点比边多一
int a, b; cin >> a >> b;
e[a].push_back(b);
e[b].push_back(a);
}
二、dfs初始化
void dfs(int x, int f) { //f是x的父节点
dep[x] = dep[f] + 1;
fa[x][0] = f;
for(int i = 1; i < 20; i++) fa[x][i] = fa[fa[x][i-1]][i-1]; // 这个写这里还是写主函数开头都可以
//i的范围与n的大小有关,一般MAX_i = log(n) + 1;
for(int y : e[x]) {
if(y == f) continue; // 建的是双向边,所有要判断,保证每个点只访问一次
dfs(y, x);
}
}
三、查询LCA
int lca(int x, int y) {
if(dep[x] < dep[y]) {
int t = x;
x = y;
y = t;
}
for(int i = 19; i >= 0; i--) {
if(dep[fa[x][i]] >= dep[y]) x = fa[x][i];
}
if(x == y) return x;
for(int i = 19; i >= 0; i--) {
if(fa[x][i] != fa[y][i]) { //注意这里细节,我们是不断条越到最近公共祖先的儿子结点下,最终返回fa[x][0]
//最终一定能调到该节点,因为该节点到x,y结点距离可以用一个数表示,这个数写成二进制例如6:110,可以先跳4(对应二进制100),然后跳2(对应二进制10)
x = fa[x][i]; y = fa[y][i];
}
}
return fa[x][0];
}
tarjan
tarjan的思路需要用到并查集,深度优先遍历。tarjan是一种离线算法,为什么叫做是离线的,因为tarjan需要知道你要查询哪些结点的lca,会提前对这些查询做个处理。
f a [ i ] fa[i] fa[i]数组,初始时都指向自己,这个数组和并查集相关。 v i s [ i ] vis[i] vis[i]数组是深度遍历用的,保证每个节点只访问一次
深度优先遍历的过程,我们每次回到当前节点时需要将子节点挂一条边到当前节点
从1出发 : dfs(1) -> dfs(2)
从2出发: dfs(2) -> dfs(3)
3结点没有子节点,递归结束,回到2,此时f[3] = 2, 如下图黑色边一样。每次离开某个节点时还要查询一下是否有和当前节点相关的LCA查询,例如右边的query(3, 4),此时离开3结点,回到2,那么需要对3进行一次判断。如何判断:若vis[3] == 1 && vis[4] != 1则不进行处理,若vis[3] == 1 && vis[4] == 1,则说明在当前深度下3、4同属一个祖先结点,通过find(3)一直找到结点2,此时2因为还没有回到1,所以fa[2] == 2自己本身,这里就是并查集的作用了,所以本次查询LCA(3, 4) == 2
代码实现:
#include<iostream>
#include<vector>
using namespace std;
const int N = 5e5 + 10;
typedef pair<int,int> pii;
vector<int> e[N];
vector<pii> query[N];
int fa[N], vis[N], ans[N];
int find(int x) {
if(fa[x] == x) return x;
return fa[x] = find(fa[x]);
}
void tarjan(int x) {
vis[x] = 1; //每次访问一个结点就打上标记
for(int y : e[x]) {
if(!vis[y]) {
tarjan(y);
fa[y] = x;
}
}
for(pii t : query[x]) { //离开当前节点时,要进行离线处理查询,若某个查询符合要求,就可以通过find找到祖先
int v = t.first, d = t.second;
if(vis[v]) ans[d] = find(v);
}
}
int main() {
int n, m, s; cin >> n >> m >> s;
for(int i = 1; i < n; i++) {// n-1条边
int x, y; cin >> x >> y;
e[x].push_back(y);
e[y].push_back(x);
}
for(int i = 1; i <= n; i++) fa[i] = i; //悲伤,不能在上面那个循环里面进行fa[i]赋值,因为上面的漏了个n,pwq...
for(int i = 1; i <= m; i++) {
int a, b; cin >> a >> b;
query[a].push_back({b, i});
query[b].push_back({a, i});
}
tarjan(s);
for(int i = 1; i <= m; i++) {
cout << ans[i] << '\n';
}
return 0;
}
tarjan处理的问题是需要提前知道查询哪些结点,是离线处理的,如果对于找一堆结点的LCA,是不能通过tarjan处理,下面举个例题:
J-尖塔第四强的高手_河南萌新联赛2024第(四)场:河南理工大学 (nowcoder.com)
在这个题目中,每次查询不是只给你两个节点x, y,而是给你一个结点集,让你求这个点集的LCA,对于这个,貌似不能用tarjan,因为x,y并不能确定下来,而应该用倍增思路来写,例如对于点集{x, y, z, q, p, t}, 应该求得ans = LCA(x, y), 再求ans = LCA(ans, y)…ans = LCA(ans, t); 所以说tarjan不能用来处理这中点集的lca,因为ans会一直变化,应该用在线处理。
题目中每次查询给的点集不会很大,因为需要满足 f i b [ i ] + x < = n ; ( n < = 1 e 5 ) fib[i] + x <= n;(n <= 1e5) fib[i]+x<=n;(n<=1e5)(斐波拉契数列) 我们容易求得fib[25]就已经超过1e5了,所以每次查询的点集最多25个结点,所以直接枚举所有点,依次进行lca查询即可
#include<iostream>
#include<vector>
using namespace std;
const int N = 5e5 + 10;
vector<int> fib(30, 0);
vector<int> e[N];
int dep[N], fa[N][20];
void dfs(int x, int f) {
dep[x] = dep[f] + 1;
fa[x][0] = f;
for(int i = 1; i < 20; i++) fa[x][i] = fa[fa[x][i-1]][i-1];
for(int y : e[x]) {
if(y == f) continue;
dfs(y, x);
}
}
int lca(int x, int y) {
if(dep[x] < dep[y]) {
int t = x;
x = y;
y = t;
}
for(int i = 19; i >= 0; i--) {
if(dep[fa[x][i]] >= dep[y]) x = fa[x][i];
}
if(x == y) return x;
for(int i = 19; i >= 0; i--) {
if(fa[x][i] != fa[y][i]) {
x = fa[x][i]; y = fa[y][i];
}
}
return fa[x][0];
}
int main() {
int n, r, m; cin >> n >> r >> m;
fib[0] = fib[1] = 1;
for(int i = 2; i < 30; i++) fib[i] = fib[i - 1] + fib[i - 2];
for(int i = 1; i < n; i++) {
int a, b; cin >> a >> b;
e[a].push_back(b);
e[b].push_back(a);
}
dfs(r, 0);
for(int d = 1; d <= m; d++) {
int x, k; cin >> x >> k;
if(k > 25) {
cout << 0 << '\n';
continue;
}
vector<int> num;
for(int i = k; i <= 25; i++) {
if(x + fib[i] <= n) num.push_back(x + fib[i]);
}
if(num.size() == 0) {
cout << 0 << '\n';
continue;
}
int ans = num[0];
for(int i = 1; i < num.size(); i++) {
ans = lca(ans, num[i]);
}
cout << ans << '\n';
}
return 0;
}