D - Misha, Grisha and Underground【LCA倍增】
题意:
给你一棵n个结点无向树。我们假定根为结点1。给你三个点,求任意两点到另外一点的最大重叠路径 + 1。
思路:
1.如果知道LCA,那这题就接近是裸题了。
2.由于有多次查询,所以要用二进制优化的离线算法。
3.可以简单证明三个点必有两个点的LCA是一样的。
4.答案就是深度大的LCA到三点距离的最大值。这个通过画图可以清晰的知道结果。
代码:
#include <bits/stdc++.h>
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
using namespace std;
typedef long long LL;
typedef pair<int,int>pii;
const int maxn = 1e5;
vector<int>mp[maxn + 10];
int up[maxn + 10][30];
int depth[maxn + 10];
int n, q, x;
//初始化各个点的深度,以及各个点的父亲,即up[child][0] = father,向上走2^0步
void dfs(int root) {
for(int i = 0; i < mp[root].size(); ++ i) {
int to = mp[root][i];
if(to == up[root][0]) continue;
depth[to] = depth[root] + 1;
up[to][0] = root;
dfs(to);
}
}
//根据up[child][0]递推出up数组
void init() {
for(int i = 1; i <= 20; ++ i) {
for(int j = 1; j <= n; ++ j) {
up[j][i] = up[ up[j][i - 1] ][i - 1];
}
}
}
//求LCA
int lca(int x, int y) {
if(depth[x] < depth[y]) swap(x, y); //x比较深
//深度之差
int h = depth[x] - depth[y];
//用二进制优化,向上走h层,之后x,y深度相同
for(int i = 0; i <= 20; ++ i) {
if(h & (1 << i))
x = up[x][i];
}
//x == y,那就说明深度小的是LCA
if(x == y) return x;
//从极限跳到两者的父节点相同,那么该节点就是LCA
//注意不能从最小的开始跳,因为那样不能保证LCA,因为可能跳到LCA以上的点,那已经是重叠的部分了,不是LCA
for(int i = 20; i >= 0; -- i) {
if(up[x][i] != up[y][i]) {
x = up[x][i];
y = up[y][i];
}
}
//返回父节点
return up[x][0];
}
int dis(int u,int v) {
//用纸画一下就能知道
return depth[u] + depth[v] - 2 * depth[ lca(u, v) ];
}
int main() {
memset(up, 0, sizeof(up));
scanf("%d%d", &n, &q);
for(int i = 2; i <= n; ++ i) {
scanf("%d", &x);
mp[x].push_back(i);
mp[i].push_back(x);
}
//初始化深度
depth[1] = 0;
//走一遍初始化up数组和depth数组
dfs(1);
//迭代更新up数组
init();
int a, b, c;
while(q--) {
scanf("%d%d%d", &a, &b, &c);
int ab = lca(a, b);
int bc = lca(b, c);
int ac = lca(a, c);
//画图可知,深度最大的点到另外三个点的距离是结果
if(depth[ab] < depth[bc]) swap(ab, bc);
if(depth[ab] < depth[ac]) swap(ab, ac);
int ans = -1;
ans = max(ans, dis(ab, a) + 1);
ans = max(ans, dis(ab, b) + 1);
ans = max(ans, dis(ab, c) + 1);
cout << ans << endl;
}
}
另解:
其实,本题答案即为以下三种情况的最大值:
借助LCA求两点之间距离,然后维护最大值即可。
盗一张图帮助理解
while(q--) {
scanf("%d%d%d", &a, &b, &c);
//另外一种思路
int x = (dis(a, b) + dis(b, c) - dis(a, c) ) / 2 + 1;
int y = (dis(a, b) + dis(a, c) - dis(b, c) ) / 2 + 1;
int z = (dis(a, c) + dis(b, c) - dis(a, b) ) / 2 + 1;
cout << max(x, max(y, z)) << endl;
}