题意:
题目链接:https://vjudge.net/problem/SPOJ-COT
对于给出的一棵树,给出若干询问,问任意两个节点的路径上第k小的节点的权值是多少?
思路:
主席树的另一种经典问题。
看似与HDU 2665不同,这里是针对树结构,其实思路基本不变。只是在建立主席树的时候,需要根据树的父子关系来构造:节点u的线段树要在其父节点的线段树基础上插入a[u]得到,因此rt[u]统计的是从根到u这条路径上所有节点的权值。
另外,在查询第k小时,需要考虑x到y的路径,也就是在“线段树”rt[x]+rt[y]-rt[lca(x,y)]-rt[pa[lca(x,y)]]上寻找第k小:rt[x]+rt[y]把根到lca(x,y)的公共部分算了两次,减去rt[lca(x,y)]和rt[pa[lca(x,y)]]之后,路径x→y上的每个节点(包括lca本身)恰好只被计数一次。
代码:
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1e5 + 10;
// Node pool for the persistent segment tree built over compressed value ranks.
// Each node stores the indices (into ns[]) of its two children and how many
// inserted values fall into the node's value range.
// Sizing: build() allocates 2*sz-1 nodes and every update() copies one node per
// level (at most 18 levels when sz <= 1e5), so the total stays below
// 2e5 + 18 * 1e5 = 2e6 <= MAXN * 20.
struct node {
    int ls, rs, sum;
} ns[MAXN * 20];
// ct: index of the most recently allocated node in ns[] (allocation
//     pre-increments it, so slot 0 is never handed out).
// rt: rt[u] is the root of the tree version for vertex u and rt[0] is the empty
//     base version. It is indexed only by vertex ids 0..n, so MAXN entries are
//     enough; it does not need to be as large as the node pool.
int ct, rt[MAXN];
// Allocate the next free node in the pool and make it an exact copy of node
// `old` (same children, same count). `now` receives the new node's index.
void cpy(int& now, int old) {
    int fresh = ++ct;
    now = fresh;
    ns[fresh] = ns[old];
}
// Allocate a complete segment tree over the rank range [l, r] with every count
// zero, storing its root in `now`. main() uses it to create rt[0].
void build(int& now, int l, int r) {
    int id = ++ct;
    now = id;
    ns[id].sum = 0;
    if (l != r) {
        int mid = (l + r) >> 1;
        build(ns[id].ls, l, mid);
        build(ns[id].rs, mid + 1, r);
    }
}
// Insert one occurrence of rank x on top of version `old` and store the new
// version's root in `now`. Every node on the root-to-leaf path for x is copied
// and its count incremented. The children that are not on the path keep
// pointing at `old`'s subtrees, so they are shared, not copied.
void update(int& now, int old, int l, int r, int x) {
    int* slot = &now;  // link that receives the next freshly copied node
    for (;;) {
        cpy(*slot, old);
        int cur = *slot;
        ++ns[cur].sum;
        if (l == r) break;
        int mid = (l + r) >> 1;
        if (x <= mid) {
            slot = &ns[cur].ls;
            old = ns[old].ls;
            r = mid;
        } else {
            slot = &ns[cur].rs;
            old = ns[old].rs;
            l = mid + 1;
        }
    }
}
// Walk four tree versions in lockstep, treating node counts as
//   cnt(s) + cnt(t) - cnt(lca) - cnt(flca),
// and return the leaf position (compressed rank) holding the k-th smallest
// counted value. main() passes rt[s], rt[t], rt[lca] and rt[pa[lca]], so that
// combination counts each vertex on the path s..t once.
int query(int s, int t, int lca, int flca, int l, int r, int k) {
    while (l != r) {
        const int mid = (l + r) >> 1;
        const int inLeft = ns[ns[s].ls].sum + ns[ns[t].ls].sum
                         - ns[ns[lca].ls].sum - ns[ns[flca].ls].sum;
        if (k <= inLeft) {
            s = ns[s].ls;
            t = ns[t].ls;
            lca = ns[lca].ls;
            flca = ns[flca].ls;
            r = mid;
        } else {
            k -= inLeft;
            s = ns[s].rs;
            t = ns[t].rs;
            lca = ns[lca].rs;
            flca = ns[flca].rs;
            l = mid + 1;
        }
    }
    return l;
}
int dfs_cnt;             // number of Euler-tour entries recorded so far
int sz;                  // number of distinct weights after coordinate compression
int pa[MAXN];            // pa[u]: parent of u in the DFS tree (0 for the root)
int dp[2 * MAXN][20];    // sparse table over Euler positions: argmin of depth R[]
int fir[2 * MAXN];       // fir[u]: first Euler-tour position of vertex u
int ver[2 * MAXN];       // ver[i]: vertex at Euler-tour position i
int R[2 * MAXN];         // R[i]: depth of the vertex at Euler-tour position i
int a[MAXN];             // a[u]: weight of u, replaced by its compressed rank
int b[MAXN];             // b[1..sz]: sorted distinct original weights
vector<int> tree[MAXN];  // adjacency lists of the input tree
// Depth-first walk from u (parent fa, depth `deep`) that does two jobs:
//  * records the Euler tour: u is appended on entry (fir[u] is that position)
//    and again after each child returns, with its depth stored in R[];
//  * builds rt[u] by inserting a[u] into the parent's version rt[fa].
void dfs(int u, int fa, int deep) {
    pa[u] = fa;
    ++dfs_cnt;
    fir[u] = dfs_cnt;
    ver[dfs_cnt] = u;
    R[dfs_cnt] = deep;
    update(rt[u], rt[fa], 1, sz, a[u]);
    for (int child : tree[u]) {
        if (child != fa) {
            dfs(child, u, deep + 1);
            ++dfs_cnt;
            ver[dfs_cnt] = u;
            R[dfs_cnt] = deep;
        }
    }
}
// Build the sparse table over Euler positions 1..n. dp[i][j] is the position
// with the smallest depth R[] inside [i, i + 2^j - 1]; on equal depth the
// position from the right half wins.
void ST(int n) {
    for (int pos = 1; pos <= n; ++pos) dp[pos][0] = pos;
    for (int lev = 1, len = 2; len <= n; ++lev, len <<= 1) {
        const int half = len >> 1;
        for (int pos = 1; pos + len - 1 <= n; ++pos) {
            const int left = dp[pos][lev - 1];
            const int right = dp[pos + half][lev - 1];
            dp[pos][lev] = (R[right] <= R[left]) ? right : left;
        }
    }
}
// Euler position of minimum depth within [l, r], answered from two overlapping
// power-of-two windows of the sparse table (LCA() always passes l <= r). On
// equal depth the right window's answer wins.
int RMQ(int l, int r) {
    const int span = r - l + 1;
    int k = 0;
    while ((2 << k) <= span) ++k;
    const int left = dp[l][k];
    const int right = dp[r - (1 << k) + 1][k];
    return (R[right] <= R[left]) ? right : left;
}
// Lowest common ancestor of u and v: the shallowest vertex on the Euler tour
// between the first occurrences of u and v.
int LCA(int u, int v) {
    int lo = fir[u];
    int hi = fir[v];
    if (hi < lo) swap(lo, hi);
    return ver[RMQ(lo, hi)];
}
// Reset global state:
//  * empty the node pool and the Euler-tour counter;
//  * build the empty base version rt[0] over [1, sz] (sz must already be set);
//  * clear adjacency lists 1..n.
void init(int n) {
    dfs_cnt = 0;
    ct = 0;
    build(rt[0], 1, sz);
    for (int v = 1; v <= n; ++v) tree[v].clear();
}
int main() {
//freopen("in.txt", "r", stdin);
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
b[i] = a[i];
}
sort (b + 1, b + 1 + n);
sz = unique(b + 1, b + 1 + n) - b - 1;
for (int i = 1; i <= n; i++) {
a[i] = lower_bound(b + 1, b + 1 + sz, a[i]) - b;
}
for (int i = 1; i <= n; i++) tree[i].clear();
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
tree[u].push_back(v);
tree[v].push_back(u);
}
ct = dfs_cnt = 0;
build(rt[0], 1, sz);
dfs(1, 0, 1);
ST(2 * n - 1);
/*for (int i = 0; i <= 5 * n; i++) {
printf("%d, rt = %d, ls = %d, rs = %d, sum = %d\n", i, rt[i], ns[rt[i]].ls, ns[rt[i]].rs, ns[rt[i]].sum);
}*/
while (m--) {
int s, t, k;
scanf("%d%d%d", &s, &t, &k);
int lca = LCA(s, t);
int flca = pa[lca];
//cout << s << " " << t << " " << lca << " " << flca << endl;
printf("%d\n", b[query(rt[s], rt[t], rt[lca], rt[flca], 1, sz, k)]);
}
return 0;
}