第一步:离散化
即:把节点的点权换成它在所有点权中的排名(它是第几小的)
将存储点权的数组复制一份之后排序,去重,然后将原先的每个点权在去重后的数组里进行二分查找,就可以得到它的排名。
第二步:建主席树
每个节点维护它到根的路径上的权值线段树,所以每个节点可以利用它的父节点更新,所以将整棵树dfs一遍,在此过程中建树。
第三步:求解
用u点的主席树+v点的主席树-lca(u,v)的主席树-lca(u,v)父节点的主席树,在这样产生的主席树上查找第k小的排名,最后输出它原来的点权。
代码如下
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cstdlib>
using namespace std;
const int e=1e5+5;
struct point
{
int l,r,w;
}c[e*30];
int fa[e][21],a[e],h[e],tot,num,deep[e],n,m,rt[e],xxx,val[e];
int next[e<<1],head[e],go[e<<1];
inline int read()
{
char ch;
int res=0;
bool f=false;
while(ch=getchar(),(ch<'0'||ch>'9')&&ch!='-');
if(ch=='-')f=true;
else res=ch-48;
while(ch=getchar(),ch>='0'&&ch<='9')
res=(res<<3)+(res<<1)+ch-48;
return f? -res:res;
}
inline void insert(int y,int &x,int l,int r,int p)
{
c[x=++num]=c[y];
c[x].w++;
if(l==r)return;
int mid=l+r>>1;
if(p<=mid)insert(c[y].l,c[x].l,l,mid,p);
else insert(c[y].r,c[x].r,mid+1,r,p);
c[x].w=c[c[x].l].w+c[c[x].r].w;
}
inline int query(int x,int y,int z,int d,int l,int r,int k)
{
if(l==r)return l;
int ret=c[c[x].l].w+c[c[y].l].w-c[c[z].l].w-c[c[d].l].w;
int mid=l+r>>1;
if(k<=ret)
return query(c[x].l,c[y].l,c[z].l,c[d].l,l,mid,k);
else return query(c[x].r,c[y].r,c[z].r,c[d].r,mid+1,r,k-ret);
}
inline void add(int x,int y)
{
next[++tot]=head[x];
head[x]=tot;
go[tot]=y;
next[++tot]=head[y];
head[y]=tot;
go[tot]=x;
}
inline void dfs2(int u)
{
insert(rt[fa[u][0]],rt[u],1,n,val[u]);
for(int i=head[u];i;i=next[i])
{
int v=go[i];
if(v==fa[u][0])continue;
dfs2(v);
}
}
inline void dfs(int u,int father)
{
deep[u]=deep[father]+1;
for(int i=0;i<=19;i++)
fa[u][i+1]=fa[fa[u][i]][i];
for(int i=head[u];i;i=next[i])
{
int v=go[i];
if(v==father)continue;
fa[v][0]=u;
dfs(v,u);
}
}
inline int lca(int x,int y)
{
if(deep[x]<deep[y])swap(x,y);
for(int i=19;i>=0;i--)
{
if(deep[fa[x][i]]>=deep[y])
x=fa[x][i];
if(x==y)return x;
}
for(int i=19;i>=0;i--)
{
if(fa[x][i]!=fa[y][i])
{
x=fa[x][i];
y=fa[y][i];
}
}
return fa[x][0];
}
inline int find(int x)
{
int l=1,r=xxx,mid;
while(l<=r)
{
mid=l+r>>1;
if(x>h[mid])
l=mid+1;
else r=mid-1;
}
return l;
}
int main()
{
int i,j,u,v,k;
n=read();
m=read();
for(i=1;i<=n;i++)
{
val[i]=read();
a[i]=val[i];
}
for(i=1;i<n;i++)
{
u=read();
v=read();
add(u,v);
}
sort(a+1,a+n+1);
h[1]=a[1];
xxx=1;
for(i=2;i<=n;i++)
if(a[i]!=a[i-1])h[++xxx]=a[i];
for(i=1;i<=n;i++)
val[i]=find(val[i]);
dfs(1,0);
int ans=0;
dfs2(1);
for(i=1;i<=m;i++)
{
u=read();
v=read();
k=read();
u^=ans;
int z=lca(u,v);
int last=query(rt[u],rt[v],rt[z],rt[fa[z][0]],1,n,k);
ans=h[last];
printf("%d",ans);
if(i!=m)putchar('\n');
}
return 0;
}