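// Persistent segment tree ("主席树") on a tree.
// Once the idea is clear the implementation is straightforward: version rt[v] holds the
// multiset of weights on the path root..v, so a path x..y is a combination of four versions.
// Note: the fourth version subtracted must be that of the LCA's parent, not the LCA itself,
// because the LCA lies on both root..x and root..y and has to stay counted exactly once.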
#include<bits/stdc++.h>
#define MAXN 1000005
typedef long long ll;
using namespace std;
// Per-vertex / per-value data (vertices are 1-indexed; vertex 0 is the sentinel parent of the root):
//   a[v]    : weight of vertex v; after compression in main() it holds the rank 1..js
//   c[r]    : original value of compressed rank r
//   rt[v]   : root (index into t[]) of the persistent segment-tree version for path root..v
//   dep[v]  : depth of v (root = 1, sentinel 0 = 0)
//   f[v][k] : 2^k-th ancestor of v (0 when it does not exist); levels 0..20 are filled
//   h[v]    : head of v's adjacency list in e[] (-1 = empty)
// Counters: tot = edges used, dex = segment-tree nodes used, js = number of distinct values.
// Node indices, depths and vertex ids always fit in int; keeping them int instead of
// long long halves the memory of these tables (f alone has MAXN * 25 entries).
ll n, m, a[MAXN], c[MAXN], tot, dex, js;
int rt[MAXN], dep[MAXN], f[MAXN][25], h[MAXN];
// Sort record used for coordinate compression: one vertex weight and the vertex it belongs to.
struct node {
    ll val;  // original weight
    ll num;  // vertex index
};
node b[MAXN];
// One node of the persistent segment tree over the compressed value range.
// t[0] is never handed out (dex is pre-incremented before use), so it stays all-zero.
// build() takes 2*js-1 nodes and each update() about log2(js)+1 more; MAXN*34 is a generous bound.
struct node2 {
    int lc;   // index of the left child in t[]
    int rc;   // index of the right child in t[]
    int cnt;  // number of inserted values falling in this node's range
};
node2 t[MAXN * 34];
// Forward-star adjacency-list edge; lists are threaded through `next` starting at h[from].
struct node3 {
    int from;  // tail vertex (written by add(), not read anywhere in this file)
    int to;    // head vertex
    int next;  // index of the next edge in the same list; -1 terminates the list
};
node3 e[MAXN << 1];
void add(int x , int y){
tot++;
e[tot].from = x;
e[tot].to = y;
e[tot].next = h[x];
h[x] = tot;
}
// Ascending order by value only; the vertex index plays no part in the comparison.
bool cmp(node x , node y){
    return y.val > x.val;
}
// Allocates a fresh segment tree over [l, r] with all counts zero and returns its root index.
int build(int l , int r){
    const int id = ++dex;
    if(l != r){
        const int mid = (l + r) >> 1;
        t[id].lc = build(l, mid);
        t[id].rc = build(mid + 1, r);
    }
    return id;
}
// Persistent insert: returns the root of a new version equal to version `op` plus one
// occurrence of compressed value `dx`. Only the nodes on the path to dx are copied;
// every other subtree is shared with `op`.
//   op   : root of the previous version
//   l, r : value range covered by `op`
//   dx   : compressed value to insert (expected to lie in [l, r])
int update(int op , int l , int r , int dx){
    int zz = ++dex;
    t[zz] = t[op];  // start as a copy of the old node (children and count)
    if(l == r){
        // The leaf must accumulate the previous count; resetting it to 1 would make a value
        // that repeats along one root path count only once.
        t[zz].cnt = t[op].cnt + 1;
        return zz;
    }
    int mid = (l + r) >> 1;
    if(dx <= mid) t[zz].lc = update(t[op].lc , l , mid , dx);
    else t[zz].rc = update(t[op].rc , mid + 1 , r , dx);
    t[zz].cnt = t[t[zz].lc].cnt + t[t[zz].rc].cnt;
    return zz;
}
int dfs(int now , int fa){
f[now][0] = fa;
dep[now] = dep[fa] + 1;
rt[now] = update(rt[fa] , 1 , js , a[now]);
for(int i = h[now] ; i != (-1) ; i = e[i].next){
if(e[i].to == fa)continue;
dfs(e[i].to , now);
}
}
// Lowest common ancestor of x and y by binary lifting over f[][] (levels 0..20).
int lca(int x , int y){
    if(dep[y] > dep[x]) std::swap(x, y);  // make x the deeper vertex
    // Lift x by exactly dep[x] - dep[y] levels, one set bit at a time. Ancestor jumps
    // commute, so scanning the bits low-to-high reaches the same vertex.
    const int gap = dep[x] - dep[y];
    for(int k = 0 ; k <= 20 ; ++k){
        if((gap >> k) & 1) x = f[x][k];
    }
    if(x == y) return x;
    // Raise both while their 2^k-th ancestors differ; they end as children of the LCA.
    for(int k = 20 ; k >= 0 ; --k){
        if(f[x][k] == f[y][k]) continue;
        x = f[x][k];
        y = f[y][k];
    }
    return f[x][0];
}
// Finds the dx-th smallest (1-based) compressed value in the combined multiset
//   version(op1) + version(op2) - version(op3) - version(op4)
// restricted to [l, r], and returns its compressed rank.
int que(int op1 , int op2 , int op3 , int op4 , int dx , int l , int r){
    while(l < r){
        const int mid = (l + r) >> 1;
        const int inLeft = t[t[op1].lc].cnt + t[t[op2].lc].cnt
                         - t[t[op3].lc].cnt - t[t[op4].lc].cnt;
        if(dx <= inLeft){
            op1 = t[op1].lc;
            op2 = t[op2].lc;
            op3 = t[op3].lc;
            op4 = t[op4].lc;
            r = mid;
        }else{
            dx -= inLeft;
            op1 = t[op1].rc;
            op2 = t[op2].rc;
            op3 = t[op3].rc;
            op4 = t[op4].rc;
            l = mid + 1;
        }
    }
    return l;
}
// Reads n weighted vertices, n-1 undirected edges, then m online queries "x y z":
// x is XOR-ed with the previous answer (0 initially), and the z-th smallest weight on the
// path x..y is printed.
int main(){
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cin >> n >> m;
    std::memset(h , -1 , sizeof(h));
    for(int i = 1 ; i <= n ; i++){
        std::cin >> a[i];
        b[i].val = a[i];
        b[i].num = i;
    }
    // Coordinate compression: c[1..js] = distinct values in ascending order,
    // a[v] = rank of v's weight among them.
    std::sort(b + 1 , b + 1 + n , cmp);
    for(int i = 1 ; i <= n ; i++){
        // Compare with the previous sorted value rather than a magic sentinel,
        // so every input value (including -999999999) is handled.
        if(i == 1 || b[i].val != b[i - 1].val){
            js++;
            c[js] = b[i].val;
        }
        a[b[i].num] = js;
    }
    rt[0] = build(1 , js);  // empty version shared by the sentinel parent 0 of the root
    for(int i = 1 ; i < n ; i++){
        int x , y;
        std::cin >> x >> y;
        add(x , y);
        add(y , x);
    }
    dfs(1 , 0);
    for(int i = 1 ; i <= 20 ; i++){
        for(int j = 1 ; j <= n ; j++){
            f[j][i] = f[f[j][i - 1]][i - 1];
        }
    }
    ll last = 0;
    for(int i = 1 ; i <= m ; i++){
        ll x , y , z , p;
        std::cin >> x >> y >> z;
        x ^= last;
        p = lca(x , y);
        // path(x, y) = root..x + root..y - root..lca - root..parent(lca):
        // subtracting the LCA's parent (not the LCA twice) keeps the LCA counted once.
        last = c[que(rt[x] , rt[y] , rt[p] , rt[f[p][0]] , z , 1 , js)];
        std::cout << last << '\n';  // no per-line flush; output is flushed at exit
    }
    return 0;
}