题意
给出一棵有边权的树,然后给出q个查询,每次查询问两个结点的路径上的边的长度的中位数是多少。
思路
这道题目是用主席树(用权值当结点)和LCA来做的。
和之前做过的区间第K大类似,这道题目是把数组转化为树。儿子结点的线段树信息是继承了父亲结点的线段树信息(数组中是第i个结点继承了第i-1个结点的信息)。这样,在查询操作的时候,我们可以找出两点的LCA,然后借助这种思想,在查询操作的时候,就和区间查询一样(区间查询是右端点的信息减去左端点-1的信息),把两个点的线段树信息相加再减去两倍LCA的线段树信息,就可以得到两点的路径的信息了,然后就转化为对这个路径求第k大了。
比如这幅图里面,2号点继承了1号点的信息,3号点继承了2号点的信息。要查询4号点和5号点之间的信息,这两个点的LCA是2号点。那么查询的时候就是4号点的权值+5号点的权值-两倍2号点的权值。
#include <bits/stdc++.h>
using namespace std;
#define N 100010
struct Edge {
int u, v, w, nxt;
} edge[N*2];
struct Node {
int l, r, sum;
} tree[N*40];
int root[N], cnt, head[N], tot, n, mx, dp[N][25], dep[N], fa[N], dis[N];
void Add(int u, int v, int w) {
edge[tot] = (Edge) {u, v, w, head[u]}, head[u] = tot++;
edge[tot] = (Edge) {v, u, w, head[v]}, head[v] = tot++;
}
int query(int left, int right, int f, int l, int r, int k) {
if(l == r) return l;
int m = (l + r) >> 1;
int sum = tree[tree[right].l].sum + tree[tree[left].l].sum - 2 * tree[tree[f].l].sum;
if(k <= sum) return query(tree[left].l, tree[right].l, tree[f].l, l, m, k);
else return query(tree[left].r, tree[right].r, tree[f].r, m + 1, r, k - sum);
}
void update(int pre, int &rt, int l, int r, int x) {
tree[++cnt] = tree[pre];
rt = cnt; tree[rt].sum++;
if(l == r) return ;
int m = (l + r) >> 1;
if(x <= m) update(tree[pre].l, tree[rt].l, l, m, x);
else update(tree[pre].r, tree[rt].r, m + 1, r, x);
}
void dfs(int u, int f) { // 一边更新主席树一边得到lca需要的信息
dp[u][0] = fa[u];
for(int i = 1; i <= 20; i++) dp[u][i] = dp[dp[u][i-1]][i-1];
for(int i = head[u]; ~i; i = edge[i].nxt) {
int v = edge[i].v, w = edge[i].w;
if(v == f) continue;
fa[v] = u; dep[v] = dep[u] + 1; dis[u] = dis[v] + w;
update(root[u], root[v], 1, mx, w);
dfs(v, u);
}
}
int LCA(int x, int y) {
if(dep[x] < dep[y]) swap(x, y);
for(int i = 20; i >= 0; i--)
if(dep[dp[x][i]] >= dep[y]) x = dp[x][i];
if(x == y) return x;
for(int i = 20; i >= 0; i--)
if(dp[x][i] != dp[y][i]) x = dp[x][i], y = dp[y][i];
return dp[x][0];
}
int main() {
int t; scanf("%d", &t);
while(t--) {
scanf("%d", &n); mx = 0;
memset(dp, 0, sizeof(dp));
memset(dis, 0, sizeof(dis));
memset(dep, 0, sizeof(dep));
memset(head, -1, sizeof(head)); tot = cnt = 0;
for(int i = 1; i < n; i++) {
int u, v, w; scanf("%d%d%d", &u, &v, &w);
Add(u, v, w);
if(w > mx) mx = w;
}
fa[1] = 1;
dfs(1, 0);
int q; scanf("%d", &q);
while(q--) {
int a, b; scanf("%d%d", &a, &b);
int f = LCA(a, b); // LCA
int p = dep[a] + dep[b] - 2 * dep[f]; // a和b路径的长度
double ans = 0;
if(p % 2) ans = (double)query(root[a], root[b], root[f], 1, mx, p / 2 + 1);
else ans = ((double)query(root[a], root[b], root[f], 1, mx, p / 2) + (double)query(root[a], root[b], root[f], 1, mx, p / 2 + 1)) / 2.0;
printf("%.1f\n", ans);
}
}
return 0;
}