CSDN代码云盘
要注意的是unique后返回的是开区间。
类似于一维主席树,这里的前缀是
r
o
o
t
[
u
]
−
r
o
o
t
[
v
]
−
r
o
o
t
[
l
c
a
(
u
,
v
)
]
+
r
o
o
t
[
f
a
[
l
c
a
(
u
,
v
)
]
]
root[u]-root[v]-root[lca(u,v)]+root[fa[lca(u,v)]]
root[u]−root[v]−root[lca(u,v)]+root[fa[lca(u,v)]],因为包含lca(u,v)这个点,所以减去的话只减一次,然后再减去父亲。
#include <bits/stdc++.h>
using namespace std;
#define FOR0(a,b) for(int i = a; i < b; ++i)
#define FORE(a,b) for(int i = a; i <= b; ++i)
typedef long long ll;
typedef pair<int,int> pii;
const int maxn = 2e5+5;
const int MX = 18;
struct node {
int l,r,sum;
}T[maxn*40];
int n,m,a[maxn],cnt, root[maxn], tot, ver[maxn],b[maxn];
int dep[maxn], fa[maxn], p[maxn][20],len;
vector<int> G[maxn];
void add(int u,int v) {
G[u].push_back(v);
}
int getid(int x) {
return lower_bound(b+1,b+1+len, x)-b;
}
void build(int l,int r,int& rt) {
rt = ++cnt;
T[rt].sum = 0;
if(l == r) return;
int mid = (l+r)>>1;
build(l,mid, T[rt].l);
build(mid+1,r,T[rt].r);
}
void update(int &x, int y, int pos, int l,int r) {
T[++cnt] = T[y]; T[cnt].sum++; x = cnt;
if(l == r) return;
int mid = (l+r)>>1;
if(mid >= pos)
update(T[x].l, T[y].l, pos, l, mid);
else
update(T[x].r, T[y].r, pos, mid+1,r);
}
int query(int x,int y,int lca, int lca_fa, int k, int l,int r) {
if(l == r) return l;
int mid = (l+r) >> 1;
int t = T[T[x].l].sum+T[T[y].l].sum-T[T[lca].l].sum-T[T[lca_fa].l].sum;
if(t >= k)
return query(T[x].l, T[y].l, T[lca].l, T[lca_fa].l, k, l,mid);
else
return query(T[x].r, T[y].r, T[lca].r, T[lca_fa].r, k-t, mid+1,r);
}
void dfs(int u,int f, int d) {
dep[u] = d; fa[u] = f;
update(root[u], root[f],getid(a[u]),1, len);
for(int i = 0; i < G[u].size(); ++i) {
int v = G[u][i];
if(v == f) continue;
p[v][0] = u;
dfs(v,u,d+1);
}
}
void initLCA() {
for(int j = 1; j <= MX; ++j)
for(int i = 1; i <= n; ++i)
p[i][j] = p[p[i][j-1]][j-1];
}
int LCA(int a,int b) {
if(dep[a] < dep[b])
swap(a,b);
for(int i = MX; i >= 0; --i) {
if(dep[p[a][i]] >= dep[b]) {
a = p[a][i];
}
}
if(a == b) return a;
for(int i = MX; i >= 0; --i) {
if(p[a][i] != p[b][i]) {
a = p[a][i];
b = p[b][i];
}
}
return p[a][0];
}
int main() {
scanf("%d%d", &n, &m);
for(int i = 1;i <= n; ++i) {
scanf("%d", &a[i]);
b[i] = a[i];
}
sort(b+1, b+1+n);
// 这里要减1,否则是[)开区间
len = unique(b+1,b+1+n)-b-1;
// cout << len << endl;
int u,v,k;
for(int i = 0; i < n-1; ++i) {
scanf("%d%d", &u, &v);
add(u,v); add(v,u);
}
build(1,len,root[0]);
dfs(1,0,1);
initLCA();
for(int i = 0; i < m; ++i) {
scanf("%d%d%d",&u, &v, &k);
int lca = LCA(u,v);
printf("%d\n", b[query(root[u],root[v],root[lca],root[fa[lca]],k,1,len)]);
}
return 0;
}