Codeforces 832D
题目链接:http://codeforces.com/problemset/problem/832/D
题意:在一棵生成树上,给出了三个点,求三个点之间最大的相交点数。
分析:三个点没有规定起止点,所以求得就是其中两个点到第三个点的路径上的重复点数。(图有点丑 见谅)
也就相当于AB,BC,AC三条路径中的公共交点到ABC三点的最大边数。分别求出边数然后取最大值再+1就可以了,在得出具体的计算公式之后这应该算一道倍增/LCA模板题(划掉) 。
需要注意的是倍增求LCA需要bfs/dfs构造深度,再预处理祖先数组,这也是算法的核心,在这里维护一个数组DP[i][j],表示下标以i为起点的长度为2^j的序列的信息。倍增法中的DP[i][j]为:结点 i 的向上 2^j 层的祖先。其中,DP[i][0]为节点i的父节点。
递推方程: DP[i][j] = DP[ DP[i][j-1] ] [j-1]。
DP[i][j-1]是结点i往上跳2^(j-1) 层的祖先, DP[i][j-1] 上再向上跳2^ (j-1)层,相当于从结点i,先跳2^ (j-1)层,再跳2^ (j-1)层,最终到达2^j层。
void bfs() {
queue<int> que;
h[1] = 1;
que.push(1);
while (!que.empty()) {
int u = que.front();
que.pop();
for (int i = head[u]; i != -1; i = e[i].next) {
int v = e[i].v;
if (h[v])
continue;
h[v] = h[u] + 1;
//求结点所在的深度
fa[v][0] = u;
//给fa数组初始化每个结点的父节点
que.push(v);
}
}
}
for (j = 1; (1 << j) <= n; j++)
for (i = 1; i <= n; i++)
if (fa[i][j - 1])
fa[i][j] = fa[fa[i][j - 1]][j - 1];
//dp数组
LCA详解:
int lca(int u, int v) {
if (h[u] < h[v])
swap(u, v);
//确保u的深度大于v
for (i = 20; i >= 0; i--)
if (h[u] - (1 << i) >= h[v])
u = fa[u][i];
//往上跳,深度--,使两者先处于同一个深度
if (u == v)
return u;
//如果已经处在同一个结点上直接return
for (i = 20; i >= 0; i--) {
//从大到小遍历回溯的高度,如果没有相遇就同时向上跳
if (fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
}
//最后一次会跳到以v,u的lca为父节点的两个节点上,返回此时u/v的父节点即可
return fa[u][0];
}
计算(枚举)距离:
int dis(int u, int v) {
return h[u] + h[v] - 2 * h[lca(u, v)];
//相当于两个点的深度相加减去两倍的相交距离(lca到根节点的距离)
}
求(最大)重合距离:
int dis1 = dis(u, v), dis2 = dis(u, t), dis3 = dis(v, t);
int maxx = max((dis1 + dis2 - dis3) / 2, max((dis1 + dis3 - dis2) / 2, (dis2 + dis3 - dis1) / 2));
全部代码:
#include <algorithm> //swap
#include <iostream>
#include <cstring>
#include <map>
#include <queue>
using namespace std;
typedef long long ll;
const int inf = 0x3f3f3f3f;
const int maxn = 200010;
int i, j, k;
int m, n, q;
int head[maxn], h[maxn];
int fa[maxn][25];
int index;
struct node {
int v, next;
} e[maxn];
void add(int u, int v) {
e[index].v = v;
e[index].next = head[u];
head[u] = index++;
}
void bfs() {
queue<int> que;
h[1] = 1;
que.push(1);
while (!que.empty()) {
int u = que.front();
que.pop();
for (int i = head[u]; i != -1; i = e[i].next) {
int v = e[i].v;
if (h[v])
continue;
h[v] = h[u] + 1;
fa[v][0] = u;
que.push(v);
}
}
}
int lca(int u, int v) {
if (h[u] < h[v])
swap(u, v);
for (i = 20; i >= 0; i--)
if (h[u] - (1 << i) >= h[v])
u = fa[u][i];
if (u == v)
return u;
for (i = 20; i >= 0; i--) {
if (fa[u][i] != fa[v][i])
u = fa[u][i], v = fa[v][i];
}
return fa[u][0];
}
int dis(int u, int v) {
return h[u] + h[v] - 2 * h[lca(u, v)];
}
int main() {
while (cin >> n >> q) {
int u, v, t;
memset(head, -1, sizeof head);
memset(h, 0, sizeof h);
memset(fa, 0, sizeof fa);
index = 0;
for (i = 2; i <= n; i++) {
cin >> u;
add(u, i);
add(i, u);
}
bfs();
for (j = 1; (1 << j) <= n; j++)
for (i = 1; i <= n; i++)
if (fa[i][j - 1])
fa[i][j] = fa[fa[i][j - 1]][j - 1];
while (q--) {
scanf("%d %d %d", &u, &v, &t);
int dis1 = dis(u, v), dis2 = dis(u, t), dis3 = dis(v, t);
int maxx = max((dis1 + dis2 - dis3) / 2, max((dis1 + dis3 - dis2) / 2, (dis2 + dis3 - dis1) / 2));
cout << maxx + 1 << endl;
}
}
return 0;
}