P2633 Count on a tree
与一般可持久化线段树相似。
这里的可持久化线段树不需要考虑区间的关系。
类别维护区间的可持久化线段树,查询时只需要查询:
rt[l-1],rt[r]
这里插入时当前位置是x,上一个位置是父亲节点(pre是fa)
查询时可以参照树上差分问题,x+y-pre-pre_fa,对比一般线段树是r-(l-1)
#include <bits/stdc++.h>
#define inf 0x7fffffff
#define ll long long
#define int long long
//#define double long double
//#define double long long
#define re register int
//#define void inline void
#define eps 1e-8
//#define mod 1e9+7
#define ls(p) p<<1
#define rs(p) p<<1|1
#define pi acos(-1.0)
#define pb push_back
#define mk make_pair
#define P pair < int , int >
using namespace std;
const int mod=9901;
//const int inf=1e18;
const int M=1e8;
const int N=2e7+5;//??????.???? 4e8
struct ndoe
{
int ver,next;
}e[N];
int n,m,ans;
int tot=1,head[N];
int a[N],rt[N],b[N];
int f[100005][65],t,d[N];
int cnt,nn;
int fa[N];
struct tree
{
int l,r,sum;
}tr[N];
void add(int x,int y)
{
e[++tot].ver=y;
e[tot].next=head[x];
head[x]=tot;
}
void addedge(int x,int y)
{
add(x,y);add(y,x);
}
void insert(int &p,int pre,int l,int r,int pos)
{
tr[++cnt]=tr[pre];
p=cnt;
tr[cnt].sum++;
if(l==r) return;
int mid=(l+r)>>1;
if(pos<=mid) insert(tr[p].l,tr[pre].l,l,mid,pos);
else insert(tr[p].r,tr[pre].r,mid+1,r,pos);
}
int ask(int x,int y,int pre,int pre_fa,int l,int r,int k)
{
if(l==r) return l;
int mid=(l+r)>>1;
int t=tr[tr[x].l].sum+tr[tr[y].l].sum-tr[tr[pre].l].sum-tr[tr[pre_fa].l].sum;
if(k<=t) return ask(tr[x].l,tr[y].l,tr[pre].l,tr[pre_fa].l,l,mid,k);
else return ask(tr[x].r,tr[y].r,tr[pre].r,tr[pre_fa].r,mid+1,r,k-t);
}
void bfs()
{
queue < int > q;
q.push(1);
d[1]=1;
insert(rt[1],rt[0],1,n,a[1]);
while(q.size())
{
int x=q.front();q.pop();
for(re i=head[x];i;i=e[i].next)
{
int y=e[i].ver;
if(d[y]) continue;
insert(rt[y],rt[x],1,n,a[y]);
d[y]=d[x]+1;
f[y][0]=x;
fa[y]=x;
for(re j=1;j<=t;j++) f[y][j]=f[f[y][j-1]][j-1];
q.push(y);
}
}
}
int lca(int x,int y)
{
if(d[x]>d[y]) swap(x,y);
for(re i=t;i>=0;i--) if(d[f[y][i]]>=d[x]) y=f[y][i];
if(x==y) return x;
for(re i=t;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
void solve()
{
cin>>n>>m;
t=(int)(log(n)/log(2))+1;
for(re i=1;i<=n;i++) scanf("%lld",&a[i]),b[i]=a[i];
for(re i=1;i<n;i++)
{
int x,y;
scanf("%lld%lld",&x,&y);
addedge(x,y);
}
sort(b+1,b+n+1);
nn=unique(b+1,b+n+1)-(b+1);
for(re i=1;i<=n;i++) a[i]=lower_bound(b+1,b+n+1,a[i])-b;
bfs();
int ans=0;
while(m --)
{
int u,v,k;
scanf("%lld%lld%lld",&u,&v,&k);
u^=ans;
int LCA=lca(u,v);
ans=b[ask(rt[u],rt[v],rt[LCA],rt[fa[LCA]],1,n,k)];
printf("%lld\n",ans);
}
}
signed main()
{
int T=1;
// cin>>T;
for(int index=1;index<=T;index++)
{
solve();
// puts("");
}
return 0;
}
/*
*/