Description
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。最后一个询问不输出换行符
求区间第k小的树上版本……其实也差不多,每一个节点一棵树,表示该节点到根的信息,最后查询的时候求一求lca,乱搞一搞就可以了……
代码:
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define maxn (int)(1e5+5)
using namespace std;
struct tree1{int fa,dep,par[22];}a[maxn];
struct tree2{int lc,rc,val;}tr[maxn*20];
struct edge{int y,next;}b[maxn*2];
struct array{int x,id;}v[maxn];
bool cmp(array u,array v) {return u.x<v.x;}
int lb=0,last[maxn];
int n,m,rank[maxn],len=0,root[maxn];
void ins(int x,int y)
{
int t=++lb;
b[t].y=y;b[t].next=last[x];last[x]=t;
}
void build(int l,int r)
{
int t=++len;tr[t].val=0;
if(l<r)
{
int mid=l+r>>1;
tr[t].lc=len+1;build(l,mid);
tr[t].rc=len+1;build(mid+1,r);
}
}
void update(int last,int l,int r,int k)
{
int t=++len;
tr[t]=tr[last];
if(l==r) {tr[t].val++;return;}
int mid=l+r>>1;
if(k<=mid) tr[t].lc=len+1,update(tr[last].lc,l,mid,k);
else tr[t].rc=len+1,update(tr[last].rc,mid+1,r,k);
tr[t].val=tr[tr[t].lc].val+tr[tr[t].rc].val;
}
void dfs(int x,int fa)
{
a[x].par[0]=fa;a[x].dep=a[fa].dep+1;
for(int i=1;i<=20;i++)
if(a[x].dep>=(1<<i)) a[x].par[i]=a[a[x].par[i-1]].par[i-1];
else break;
root[x]=len+1;update(root[fa],1,n,rank[x]);
for(int i=last[x];i!=-1;i=b[i].next) if(b[i].y!=fa) dfs(b[i].y,x);
}
int get_lca(int x,int y)
{
if(a[x].dep<a[y].dep) {int t=x;x=y;y=t;}
for(int i=20;i>=0;i--) if((1<<i)<=a[x].dep-a[y].dep) x=a[x].par[i];
if(x==y) return x;
for(int i=20;i>=0;i--)
if((1<<i)<=a[x].dep && a[x].par[i]!=a[y].par[i]) x=a[x].par[i],y=a[y].par[i];
return a[x].par[0];
}
int query(int l,int r,int k,int rt1,int rt2,int rt3,int rt4)
{
if(l==r) return v[l].x;
int mid=l+r>>1,tt=tr[tr[rt1].lc].val+tr[tr[rt2].lc].val-tr[tr[rt3].lc].val-tr[tr[rt4].lc].val;
if(k<=tt) return query(l,mid,k,tr[rt1].lc,tr[rt2].lc,tr[rt3].lc,tr[rt4].lc);
else return query(mid+1,r,k-tt,tr[rt1].rc,tr[rt2].rc,tr[rt3].rc,tr[rt4].rc);
}
int main()
{
int lastans=0;
memset(last,-1,sizeof(last));
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
{
scanf("%d",&v[i].x);
v[i].id=i;
}
sort(v+1,v+1+n,cmp);for(int i=1;i<=n;i++) rank[v[i].id]=i;
//for(int i=1;i<=n;i++) printf("%d ",rank[i]);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
ins(x,y);ins(y,x);
}
root[0]=1;build(1,n);
a[0].dep=0;dfs(1,0);
for(int i=1;i<=m;i++)
{
int x,y,k,lca,fa;
scanf("%d%d%d",&x,&y,&k);
x^=lastans;
lca=get_lca(x,y);fa=a[lca].par[0];
//printf("%d %d %d %d %d\n",x,y,k,lca,fa);
lastans=query(1,n,k,root[x],root[y],root[lca],root[fa]);
printf("%d",lastans);if(i!=m) printf("\n");
}
}