在二叉树中找到两个节点的最近公共祖先(再进阶)
题目描述:
给定一棵二叉树,多次给出这棵树上的两个节点 o1 和 o2,请对于每次询问,找到 o1 和 o2 的最近公共祖先节点。
输入描述:
第一行输入两个整数 n 和 root,n 表示二叉树的总节点个数,root 表示二叉树的根节点。
以下 n 行每行三个整数 fa,lch,rch,表示 fa 的左儿子为 lch,右儿子为 rch。(如果 lch 为 0 则表示 fa 没有左儿子,rch同理)
第 n+2 行输入一个整数 m,表示询问的次数。
以下 m 行每行两个节点 o1 和 o2。
输出描述:
对于每组询问每行输出一个整数表示答案。
示例1
输入
8 1
1 2 3
2 4 5
4 0 0
5 0 0
3 6 7
6 0 0
7 8 0
8 0 0
4
4 5
5 2
6 8
5 8
输出
2
2
3
1
备注:
1 ≤ n ≤ 1 0 5 1 \leq n \leq 10^5 1≤n≤105
1 ≤ m ≤ 1 0 5 1 \leq m \leq 10^5 1≤m≤105
1 ≤ f a , l c h , r c h , r o o t , o 1 , o 2 ≤ n 1 \leq fa,lch,rch,root,o_1,o_2 \leq n 1≤fa,lch,rch,root,o1,o2≤n
o 1 ≠ o 2 o_1 \neq o_2 o1=o2
题解:
离线解法:
Tarjan + 并查集,通过并查集维护两个节点的最近公共祖先。参考 https://www.cnblogs.com/JVxie/p/4854719.html
注意:此题测试数据出现森林的情况,debug了个寂寞。。。
在线解法:
ST/倍增/树链剖分,以后用到再说吧,不在搞竞赛了,应该用不上。
代码:
#include <cstdio>
#include <vector>
#include <map>
using namespace std;
const int N = 100010;
typedef pair<int, int> PII;
vector<int> g[N];
vector<int> q[N];
int fa[N];
bool vis[N];
int n, rt;
int _fa, lch, rch;
int m;
int o1[N], o2[N];
int deg[N];
map<PII, int> ret;
int _find(int x) {
return fa[x] == x ? x : fa[x] = _find(fa[x]);
}
void _merge(int u, int v) {
u = _find(u);
v = _find(v);
if (u == v) return;
fa[v] = u;
}
void LCA(int root) {
vis[root] = true;
int sze = g[root].size();
for (int i = 0; i < sze; ++i) {
if (!vis[g[root][i]]) {
LCA(g[root][i]);
_merge(root, g[root][i]);
}
}
sze = q[root].size();
for (int i = 0; i < sze; ++i) {
if (vis[q[root][i]]) {
int l = min(root, q[root][i]);
int r = max(root, q[root][i]);
PII pii = {l, r};
if (!ret.count(pii)) ret[pii] = _find(q[root][i]);
}
}
}
int main(void) {
scanf("%d%d", &n, &rt);
for (int i = 1; i <= n; ++i) {
fa[i] = i;
scanf("%d%d%d", &_fa, &lch, &rch);
if (lch) {
g[_fa].push_back(lch);
deg[lch] += 1;
}
if (rch) {
g[_fa].push_back(rch);
deg[rch] += 1;
}
}
scanf("%d", &m);
for (int i = 0; i < m; ++i) {
scanf("%d%d", o1 + i, o2 + i);
q[o1[i]].push_back(o2[i]);
q[o2[i]].push_back(o1[i]);
}
for (int i = 1; i <= n; ++i) if (!deg[i]) LCA(i);
for (int i = 0; i < m; ++i) {
PII pii = {min(o1[i], o2[i]), max(o1[i], o2[i])};
printf("%d\n", ret[pii]);
}
return 0;
}