题目来源:http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3195
给定一棵带边权的树，每次查询连通给定三点所需的最小连通子树（即连接这三点的路径之并）的边权总和。
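设给定的三点为 x, y, z，dist(u,v) 表示树上两点间路径的边权和（借助 LCA 求得；下面代码中的 lca(u,v) 函数返回的正是这个距离，而不是公共祖先本身）。则三点之间的路径总长度 = (dist(x,y)+dist(y,z)+dist(x,z))/2，因为三条两两之间的路径恰好把最小连通子树中的每条边各计算了两次。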
代码:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cstdlib>
#include <iostream>
#define ll long long
using namespace std;
// Capacity limits: up to 5e4 vertices plus slack. 2^(maxh - 1) = 2^18 exceeds
// 5e4, so maxh jump levels cover any root-to-vertex path.
const int maxn = 50000 + 10;
const int maxh = 19;

int n;                // number of vertices in the current test case
int head[maxn];       // head[v]: index in e[] of v's first edge (0 = no edges)
int cnt;              // number of directed edges stored in e[] so far
int q;                // number of queries in the current test case
int s[maxn];          // explicit DFS stack used by build()
int anc[maxn][maxh];  // anc[v][i]: ancestor 2^i levels above v (the root maps to itself)
int lcs[maxn][maxh];  // lcs[v][i]: total edge weight from v up to anc[v][i]
int dep[maxn];        // dep[v]: depth of v; the root has depth 1

// One directed half of an undirected weighted edge, chained into a
// singly-linked adjacency list per source vertex.
struct edge {
    int to;    // target vertex
    int next;  // index of the next edge with the same source (0 ends the list)
    int vi;    // edge weight
} e[maxn * 2];
void ins(int x, int y, int z) {
e[++cnt].to = y;
e[cnt].next = head[x];
e[cnt].vi = z;
head[x] = cnt;
}
// Fill dep[], anc[][] and lcs[][] for the tree containing `root`, traversing it
// with the explicit stack s[] instead of recursion.
void build(int root) {
    dep[root] = 1;
    // The root is its own ancestor at every level, at distance 0; this makes
    // jumps that would overshoot the root stop there.
    for (int lvl = 0; lvl < maxh; ++lvl) {
        anc[root][lvl] = root;
        lcs[root][lvl] = 0;
    }

    int top = 0;
    s[++top] = root;
    while (top > 0) {
        const int u = s[top];
        --top;

        // u is only pushed after its parent was popped, so every ancestor's
        // table is already complete and the doubling step below is valid.
        if (u != root) {
            for (int lvl = 1; lvl < maxh; ++lvl) {
                const int mid = anc[u][lvl - 1];
                anc[u][lvl] = anc[mid][lvl - 1];
                lcs[u][lvl] = lcs[u][lvl - 1] + lcs[mid][lvl - 1];
            }
        }

        for (int id = head[u]; id != 0; id = e[id].next) {
            const int v = e[id].to;
            if (v == anc[u][0]) continue;  // skip the edge back to the parent
            dep[v] = dep[u] + 1;
            anc[v][0] = u;
            lcs[v][0] = e[id].vi;
            s[++top] = v;
        }
    }
}
// Lift x (updated in place through the reference) by h levels using the binary
// decomposition of h, and return the total edge weight of the path climbed.
// A non-positive h leaves x unchanged and returns 0.
int swim(int &x, int h) {
    int walked = 0;
    for (int bit = 0; h > 0; ++bit, h >>= 1) {
        if ((h & 1) == 0) continue;  // this power of two is not part of h
        walked += lcs[x][bit];
        x = anc[x][bit];
    }
    return walked;
}
int lca(int x, int y) {
int tot = 0;
if (dep[x] > dep[y])swap(x, y);
tot += swim(y, dep[y] - dep[x]);
if (x == y)return tot;
while (1) {
int pos;
for (pos = 0; anc[x][pos] != anc[y][pos]; ++pos);
if (pos == 0) {
tot += lcs[x][0];
tot += lcs[y][0];
return tot;
}
tot += lcs[x][pos - 1];
x = anc[x][pos - 1];
tot += lcs[y][pos - 1];
y = anc[y][pos - 1];
}
}
// Total edge weight of the smallest subtree connecting x, y and z.
// The three pairwise paths together cover every edge of that subtree exactly
// twice, so their summed length is even and halving it gives the answer.
int solve(int x, int y, int z) {
    const int pairwise = lca(x, y) + lca(y, z) + lca(x, z);
    return pairwise / 2;
}
// Reads test cases until end of input. Each case is n, then n-1 weighted edges
// (vertices numbered from 0, tree rooted at 0), then q queries of three vertices.
// Prints one answer per query, with a blank line between consecutive test cases.
int main() {
    bool first_case = true;
    // Require exactly one conversion: `~scanf(...)` only stopped on EOF, so a
    // non-numeric token (scanf returning 0) made the loop spin forever.
    while (scanf("%d", &n) == 1) {
        if (!first_case) printf("\n");
        first_case = false;

        cnt = 0;
        memset(head, 0, sizeof(head));
        memset(dep, 0, sizeof(dep));
        memset(lcs, 0, sizeof(lcs));

        int x, y, z;
        for (int i = 1; i < n; ++i) {
            // Stop on truncated input instead of reusing stale x/y/z.
            if (scanf("%d%d%d", &x, &y, &z) != 3) return 0;
            ins(x, y, z);
            ins(y, x, z);
        }
        build(0);

        if (scanf("%d", &q) != 1) return 0;
        for (int i = 1; i <= q; ++i) {
            if (scanf("%d%d%d", &x, &y, &z) != 3) return 0;
            printf("%d\n", solve(x, y, z));
        }
    }
    return 0;
}