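/*
 * Problem:
 *   Find the shortest total distance connecting three given nodes of a tree.
 * Approach:
 *   Compute the shortest distance between each pair of the three nodes, add the
 *   three distances together, and divide by 2. Every edge of the minimal subtree
 *   joining the three nodes lies on exactly two of the three pairwise paths, so
 *   the sum counts each such edge twice.
 */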
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cmath>
using namespace std;
// Capacity of the node-indexed arrays (5e4 nodes plus slack).
const int maxn = 5e4 + 10;

// Counts test cases started so far; main uses it to print a blank line
// between consecutive test cases.
int flag = 0;
int n;    // number of nodes in the current tree
int q;    // number of queries remaining for the current tree
int tot;  // index of the most recently stored edge in edges[] (edges start at 1)

int head[maxn];  // head[u]: index of the first edge in u's adjacency list, -1 if none
int dis[maxn];   // dis[v]: sum of edge weights on the path from the DFS root to v

int tim;             // number of Euler-tour positions written so far (positions start at 1)
int id[maxn*2];      // id[t]: node visited at tour position t
int F[maxn];         // F[u]: first tour position at which node u appears
int R[maxn*2];       // R[t]: depth recorded at tour position t
int dp[25][maxn*2];  // sparse table over R: dp[k][j] is the tour position of minimum
                     // depth in the window of length 2^k starting at j

// One directed edge of the adjacency list: target node, weight, and the index
// of the next edge leaving the same source (-1 terminates the list).
struct Edge {
    int to;
    int d;
    int next;
};
Edge edges[maxn*2];
// Resets the per-test-case state: clears the edge counter and the Euler-tour
// clock, zeroes the root distance, and marks the adjacency list of every node
// in [0, n) as empty.
void init() {
    tim = 0;
    tot = 0;
    dis[0] = 0;
    int node = 0;
    while (node < n) {
        head[node] = -1;  // -1 is the list terminator tested by dfs
        ++node;
    }
}
void addedge(int u, int v, int d) {
edges[++tot].to = v; edges[tot].d = d; edges[tot].next = head[u]; head[u] = tot;
}
// Depth-first traversal that builds the Euler tour and the root distances.
//   u   - node being visited
//   p   - node we arrived from (its edges back are skipped)
//   dep - depth value to record for u
// On entry u is appended to the tour and its first position is stored in F[u];
// after returning from each child, u is appended to the tour again.
void dfs(int u, int p, int dep) {
    ++tim;
    id[tim] = u;
    F[u] = tim;
    R[tim] = dep;
    for (int e = head[u]; e != -1; e = edges[e].next) {
        const int child = edges[e].to;
        if (child != p) {
            dis[child] = dis[u] + edges[e].d;
            dfs(child, u, dep + 1);
            // Re-record u so that any tour range spanning two subtrees of u
            // contains a position with u's depth.
            ++tim;
            id[tim] = u;
            R[tim] = dep;
        }
    }
}
// Builds the sparse table dp over the Euler-tour depths R[1..tim].
// dp[k][j] holds the tour position with the smallest depth in the window of
// length 2^k starting at j. When the second half of a window would start past
// tim, the entry falls back to the first half's answer.
void ST() {
    for (int j = 1; j <= tim; j++) dp[0][j] = j;

    // Number of levels: floor(log2(tim)) + 1, computed with integers so the
    // result does not depend on floating-point rounding and is defined for
    // tim == 0.
    int len = 0;
    while ((1 << len) <= tim) ++len;

    for (int i = 1; i <= len; i++) {
        const int half = 1 << (i - 1);
        for (int j = 1; j <= tim; j++) {
            const int x = dp[i-1][j];
            // Check the bound BEFORE touching dp[i-1][j + half]: for large
            // tours j + half can exceed the row length of dp.
            if (j + half <= tim) {
                const int y = dp[i-1][j + half];
                dp[i][j] = (R[x] < R[y] ? x : y);
            } else {
                dp[i][j] = x;
            }
        }
    }
}
// Returns the Euler-tour position with the smallest recorded depth in the
// inclusive range [l, r], by comparing the two length-2^k windows that cover
// the range. Callers are expected to pass l <= r.
int query(int l, int r) {
    const int span = r - l + 1;
    const int k = static_cast<int>(log2(span));
    const int left = dp[k][l];
    const int right = dp[k][r - (1 << k) + 1];
    if (R[left] < R[right]) {
        return left;
    }
    return right;
}
// Returns the lowest common ancestor of nodes x and y: the node at the
// minimum-depth tour position between their first occurrences.
int lca(int x, int y) {
    int lo = F[x];
    int hi = F[y];
    if (lo > hi) {
        // Order the two first-occurrence positions so that lo <= hi.
        const int tmp = lo;
        lo = hi;
        hi = tmp;
    }
    const int pos = query(lo, hi);
    return id[pos];
}
// Returns the weighted distance between nodes x and y: the two root distances
// minus twice the root distance of their lowest common ancestor.
int ans(int x, int y) {
    const int meet = lca(x, y);
    return dis[x] + dis[y] - 2 * dis[meet];
}
// Reads test cases until input is exhausted. Each case gives n, then n-1
// weighted edges, then q, then q triples of nodes. For every triple it prints
// half the sum of the three pairwise distances. A blank line separates the
// outputs of consecutive test cases.
int main() {
    // Compare against the expected field count: scanf returns 0 (not EOF) on a
    // matching failure, which a plain EOF test would treat as success.
    while (scanf("%d", &n) == 1) {
        if (flag++) printf("\n");
        init();
        for (int i = 0; i < n-1; i++) {
            int u, v, w;
            if (scanf("%d%d%d", &u, &v, &w) != 3) return 0;  // truncated/malformed input
            // The tree is undirected: store both directions.
            addedge(u, v, w); addedge(v, u, w);
        }
        dfs(0, -1, 1);  // root the tree at node 0, depth 1
        ST();
        if (scanf("%d", &q) != 1) return 0;
        while (q-- > 0) {
            int a, b, c;
            if (scanf("%d%d%d", &a, &b, &c) != 3) return 0;
            printf("%d\n", (ans(a, b) + ans(b, c) + ans(c, a)) / 2);
        }
    }
    return 0;
}