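/*
 * Path k-th smallest on a tree with a persistent segment tree ("chairman tree").
 * A persistent segment tree is essentially a segment tree combined with prefix sums:
 * the version t[v] stores the counts accumulated over every vertex on the path from
 * the root to v (each version is built on top of its parent's version). For a query
 * (u, v), let root = lca(u, v) and let fa[root] be root's parent. The number of path
 * values that fall into the left half of the current range is
 *     tVal[lt[t[u]]] + tVal[lt[t[v]]] - tVal[lt[t[root]]] - tVal[lt[t[fa[root]]]].
 * Pitfall: my first attempt subtracted root's count twice (2 * tVal[lt[t[root]]]).
 * That removes root itself from the path, so root's value went missing from the
 * answer. Subtract root once and fa[root] once -- keep this in mind whenever an LCA
 * is used for a path query.
 */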
#include <bits/stdc++.h>
using namespace std;
const int maxn = 2*(int)1e6+100;
// a[i]: input value of vertex i; b[1..len]: sorted distinct values (rank lookup).
int n, m, a[maxn], b[maxn], len;
// Persistent segment tree node pool: tVal = count, lt/rt = child ids.
// Every update() allocates one node per tree level, hence the maxn*40 capacity.
int tot = 0, tVal[maxn*40], lt[maxn*40], rt[maxn*40];
// t[v]: root id of vertex v's version (t[0] is the empty tree); fa[v]: parent of v.
// t is indexed by vertex, so maxn entries suffice (it used to be maxn*40, wasting
// ~300 MB of static storage for no benefit).
int t[maxn], fa[maxn];
// Tree adjacency lists used by the persistent-tree DFS.
vector< vector<int> > G(maxn);
/*
 * Lowest common ancestor via an Euler tour plus a sparse-table range-minimum query.
 * Usage: init(n); addEdge(a, b, w) for every edge; solve(root); then
 * queryLCA(u, v), getDist(u, v) (weighted) or getLen(u, v) (edge count).
 * queryLCA does O(1) work after the sparse tables are built in solve().
 */
struct LCA
{
#define type int
    // Adjacency entry: neighbour `to` reached through an edge of weight `w`.
    struct node{int to;type w;node(){}node(int _to,type _w):to(_to),w(_w){}};
    type dist[maxn];  // dist[v]: weighted distance from the root to v
    // path[i] / dep[i]: vertex at Euler-tour position i (1-based) and its depth.
    // loc[v]: first tour position of v (0 = not visited yet).
    // len[v]: number of edges on the path from the root to v.
    // LOG[i]: floor(log2(i)).  all: tour length.  n: vertex bound given to init().
    int path[maxn],dep[maxn],loc[maxn],len[maxn],LOG[maxn],all,n;
    // Sparse tables: dp[i][j] = minimum depth in tour[j, j + 2^i), point[i][j] = that vertex.
    // 2^20 ~ 1e6, 2^25 ~ 3.3e7, so 25 levels cover any tour that fits in maxn.
    int dp[25][maxn], point[25][maxn];
    vector<node> G[maxn];

    // Appends the Euler tour of the component containing u, recording depth `now`
    // for u. Iterative (explicit stack) so deep, path-shaped trees cannot overflow
    // the call stack; the tour produced is the same as the recursive walk's.
    void dfs(int u, int now) {
        struct Frame { int v; int depth; size_t next; };
        std::vector<Frame> stk;
        path[++all] = u;
        loc[u] = all;
        dep[all] = now;
        stk.push_back({u, now, 0});
        while (!stk.empty()) {
            Frame &top = stk.back();
            const int x = top.v;
            const int d = top.depth;
            if (top.next == G[x].size()) {
                stk.pop_back();
                if (!stk.empty()) {
                    // Back in the parent after finishing a child: it reappears in the tour.
                    path[++all] = stk.back().v;
                    dep[all] = stk.back().depth;
                }
                continue;
            }
            const node &cur = G[x][top.next++];
            const int v = cur.to;
            if (loc[v]) continue;  // already on the tour (the parent edge)
            // Edge count from the root. The old `now+1` made non-root vertices one
            // larger than the root's 0, which broke getLen when the LCA is the root.
            len[v] = len[x] + 1;
            dist[v] = dist[x] + cur.w;
            path[++all] = v;
            loc[v] = all;
            dep[all] = d + 1;
            stk.push_back({v, d + 1, 0});  // may reallocate; `top` is not used afterwards
        }
    }
    // Builds the sparse tables over tour positions 1..all.
    // The argument is unused: the tour length `all` is read directly.
    void initRMQ(int /*unused*/)
    {
        LOG[0] = -1;
        for (int i = 1; i <= all; ++i) {
            dp[0][i] = dep[i];
            point[0][i] = path[i];
            LOG[i] = ((i&(i-1)) == 0 ? LOG[i-1]+1 : LOG[i-1]);
        }
        for (int i = 1; (1<<i) <= all; ++i) {
            for (int j = 1; j+(1<<i)-1 <= all; ++j) {
                if (dp[i-1][j] < dp[i-1][j+(1<<(i-1))]) {
                    dp[i][j] = dp[i-1][j];
                    point[i][j] = point[i-1][j];
                } else {
                    dp[i][j] = dp[i-1][j+(1<<(i-1))];
                    point[i][j] = point[i-1][j+(1<<(i-1))];
                }
            }
        }
    }
    // Returns the LCA of vertices l and r: the shallowest vertex between their
    // first tour positions, answered from two overlapping power-of-two windows.
    int queryLCA(int l,int r)
    {
        l = loc[l]; r = loc[r];
        if(l>r) swap(l,r);
        int k = LOG[r-l+1];
        /*
        Alternative from the original notes: the loop below seemed faster on some
        inputs and slower on others. If you switch to it, remove the LOG
        precomputation in initRMQ. (Tried on problem P3379.)
        int k=0;
        while((1<<k)<=r-l+1) k++;
        k--;
        */
        if(dp[k][l] < dp[k][r-(1<<k)+1]) return point[k][l];
        else return point[k][r-(1<<k)+1];
    }
    // Weighted distance between a and b.
    type getDist(int a,int b){return dist[a]+dist[b]-2*dist[queryLCA(a,b)];}
    // Number of edges on the path between a and b.
    int getLen(int a,int b){return len[a]+len[b]-2*len[queryLCA(a,b)];}
    // Resets per-vertex state for vertices 0.._n and the tour length.
    void init(int _n)
    {
        n = _n;
        all = 0;
        for(int i = 0;i <= n; i++)
        {
            loc[i] = 0;
            dist[i] = 0;
            len[i] = 0;
            G[i].clear();
        }
    }
    // Adds an undirected edge a-b of weight w.
    void addEdge(int a,int b,type w=1)
    {
        G[a].emplace_back(node(b,w));
        G[b].emplace_back(node(a,w));
    }
    // Roots the tree at `root` (recorded depth 1) and builds the RMQ tables.
    void solve(int root)
    {
        dfs(root, 1);
        initRMQ(all);
    }
#undef type
}lca;
// Allocates an empty segment tree covering [l, r] and returns its root id.
// Counts stay 0 because the node pool is zero-initialized; leaves get no children.
int build(int l, int r) {
    const int root = ++tot;
    if (l >= r) return root;  // leaf (or empty range): nothing to link
    const int mid = (l + r) >> 1;
    lt[root] = build(l, mid);
    rt[root] = build(mid + 1, r);
    return root;
}
// Persistent point insert: returns the root of a new version equal to `par` plus one
// occurrence of rank p in [l, r]. Only the nodes on the root-to-p path are copied;
// every other subtree is shared with `par`.
int update(int l, int r, int par, int p) {
    const int fresh = ++tot;
    tVal[fresh] = tVal[par] + 1;
    lt[fresh] = lt[par];
    rt[fresh] = rt[par];
    if (l >= r) return fresh;  // reached the leaf for p
    const int mid = (l + r) >> 1;
    if (p > mid) rt[fresh] = update(mid + 1, r, rt[par], p);
    else lt[fresh] = update(l, mid, lt[par], p);
    return fresh;
}
// Returns the rank in [l, r] of the k-th smallest value on the path u..v, given the
// versions ql = t[u], qr = t[v], par = t[lca], parpar = t[fa[lca]].
// The path count of a range is count(u) + count(v) - count(lca) - count(fa[lca]).
int query(int l, int r, int parpar, int par, int ql, int qr, int k) {
    while (l < r) {
        const int mid = (l + r) >> 1;
        const int leftCount = tVal[lt[ql]] + tVal[lt[qr]] - tVal[lt[par]] - tVal[lt[parpar]];
        if (k <= leftCount) {
            // Answer lies in [l, mid]: descend left in all four versions.
            r = mid;
            parpar = lt[parpar];
            par = lt[par];
            ql = lt[ql];
            qr = lt[qr];
        } else {
            // Skip the left half's values, then descend right.
            l = mid + 1;
            k -= leftCount;
            parpar = rt[parpar];
            par = rt[par];
            ql = rt[ql];
            qr = rt[qr];
        }
    }
    return l;
}
// Builds the persistent version of every vertex reachable from x (x's parent is `par`).
// t[v] = t[parent] plus one occurrence of a[v]'s compressed rank, so t[v] holds the
// value counts of the whole root..v path. Also records fa[v].
// Iterative DFS visiting vertices in the same preorder as the recursive version,
// so a deep, path-shaped tree cannot overflow the call stack.
void dfs(int x, int par) {
    struct Frame { int v; int parent; size_t next; };
    vector<Frame> stk;
    // Visit v: record its parent, derive its version from the parent's, then push it.
    auto enter = [&](int v, int p) {
        fa[v] = p;
        const int pos = lower_bound(b + 1, b + 1 + len, a[v]) - b;
        t[v] = update(1, len, t[p], pos);
        stk.push_back({v, p, 0});
    };
    enter(x, par);
    while (!stk.empty()) {
        Frame &top = stk.back();
        if (top.next == G[top.v].size()) {
            stk.pop_back();
            continue;
        }
        const int to = G[top.v][top.next++];
        if (to == top.parent) continue;
        enter(to, top.v);  // may reallocate stk; `top` is not used afterwards
    }
}
// Reads the next decimal integer from stdin via getchar().
// Non-digit characters before the number are skipped; a '-' among them makes the
// result negative. Returns 0 if input ends before any digit is found (previously
// the skip loop spun forever on EOF).
int read() {
    int ans = 0, f = 1;
    int c = getchar();  // int, not char: a char cannot reliably represent EOF
    for (; c != EOF && (c < '0' || c > '9'); c = getchar())
        if (c == '-') f = -1;
    for (; c >= '0' && c <= '9'; c = getchar()) ans = ans * 10 + (c - '0');
    return ans * f;
}
// Reads n, m, the n vertex values and n-1 edges, then answers m queries "u v k"
// by printing the k-th smallest value on the tree path between u and v.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.precision(10);
    cout << fixed;
#ifdef LOCAL_DEFINE
    freopen("input.txt", "r", stdin);
#endif
    // No memset of tVal: it is a global, hence already zero-initialized, and clearing
    // it explicitly would only force the OS to commit every page of that huge array.
    n = read(); m = read();
    lca.init(n);
    for (int i = 1; i <= n; ++i) {
        a[i] = read();
        b[i] = a[i];
    }
    // Coordinate compression: b[1..len] holds the sorted distinct values.
    sort(b + 1, b + 1 + n);
    len = unique(b + 1, b + 1 + n) - b - 1;
    for (int i = 1; i <= n - 1; ++i) {
        int u, v;
        u = read(); v = read();
        G[u].emplace_back(v);
        G[v].emplace_back(u);
        lca.addEdge(u, v, 1);
    }
    lca.solve(1);
    t[0] = build(1, len);  // version 0: empty tree, parent version of the root
    dfs(1, 0);
    for (int i = 1; i <= m; ++i) {
        int u, v, k;
        u = read(); v = read(); k = read();
        int temp = lca.queryLCA(u, v);
        // Subtract the LCA and its parent once each so the LCA itself stays counted.
        int pos = query(1, len, t[fa[temp]], t[temp], t[u], t[v], k);
        cout << b[pos] << '\n';
    }
#ifdef LOCAL_DEFINE
    cerr << "Time elapsed: " << 1.0 * clock() / CLOCKS_PER_SEC << " s.\n";
#endif
    return 0;
}