问题描述
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。最后一个询问不输出换行符
Sample Input
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
Sample Output
2
8
9
105
7
Hint
N,M<=100000
分析:
在dfs的过程中建立主席树(每个节点上一版本的树为父节点建立的树)
r(i)表示树i
假设要求的两个点为x、y
x、y的lca为z,
平常lca求x、y的距离的时候是dist(x)+dist(y)-2dist(z)
则x到y路径上的权值线段树为r(x)+r(y)-r(z)-r(fa(z))
为什么最后一个是fa(z)不是z呢
因为这题统计的是点权,z这个点必须保留一个,所以最后一个的是fa(z),即z的父节点
最后第k大就是基本操作了
还有一个地方很关键,就是根节点的父节点不能是自己,(否则无限RE),
平常lca都是直接把根节点的父节点设置为自己,但是这题这样设置会出错(ask的时候如果lca为根会多减)
要设为0(r(0)是空树,减了不影响)
code:
#include<cstdio>
#include<cstring>
#include<iostream>
#include<cmath>
#include<algorithm>
using namespace std;
const int maxm=1e5+5;
struct Node{
int lc,rc,cnt;
}a[maxm*40];
int r[maxm],cnt;
int b[maxm],xx[maxm],num;
int head[maxm],nt[maxm<<1],to[maxm<<1],tot;
int f[maxm][20],d[maxm],maxd;
int lastans;
int n,m;
void init(){
memset(head,-1,sizeof head);
}
void add(int x,int y){
tot++;nt[tot]=head[x];head[x]=tot;to[tot]=y;
}
int build(int l,int r){
int k=cnt++;
a[k].cnt=0;
if(l!=r){
int mid=(l+r)/2;
a[k].lc=build(l,mid);
a[k].rc=build(mid+1,r);
}
return k;
}
int update(int x,int val,int l,int r,int last){
int k=cnt++;
a[k]=a[last];
a[k].cnt+=val;
if(l!=r){
int mid=(l+r)/2;
if(x<=mid)a[k].lc=update(x,val,l,mid,a[last].lc);
else a[k].rc=update(x,val,mid+1,r,a[last].rc);
}
return k;
}
void dfs(int x,int dep,int pre){
d[x]=dep;
r[x]=update(b[x],1,1,num,r[pre]);
for(int i=head[x];i!=-1;i=nt[i]){
int v=to[i];
if(!d[v]){
d[v]=d[x]+1;
f[v][0]=x;
dfs(v,dep+1,x);
}
}
}
int lca(int a,int b){
if(d[a]<d[b])swap(a,b);
for(int i=maxd;i>=0;i--){
if(d[f[a][i]]>=d[b]){
a=f[a][i];
}
}
if(a==b)return a;
for(int i=maxd;i>=0;i--){
if(f[a][i]!=f[b][i]){
a=f[a][i],b=f[b][i];
}
}
return f[a][0];
}
int ask(int k,int x,int y,int fa,int ffa,int l,int r){
if(l==r)return l;
int mid=(l+r)/2;
int res=a[a[x].lc].cnt+a[a[y].lc].cnt-a[a[fa].lc].cnt-a[a[ffa].lc].cnt;
if(res>=k)return ask(k,a[x].lc,a[y].lc,a[fa].lc,a[ffa].lc,l,mid);
return ask(k-res,a[x].rc,a[y].rc,a[fa].rc,a[ffa].rc,mid+1,r);
}
signed main(){
init();
scanf("%d%d",&n,&m);
maxd=(int)(log(n)/log(2))+1;
for(int i=1;i<=n;i++){
scanf("%d",&b[i]);
xx[i]=b[i];
}
sort(xx+1,xx+1+n);
num=unique(xx+1,xx+1+n)-xx-1;
for(int i=1;i<=n;i++){//离散化
b[i]=lower_bound(xx+1,xx+1+num,b[i])-xx;
}
for(int i=1;i<=n-1;i++){
int x,y;
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
r[0]=build(1,num);
f[1][0]=0;这个地方是00000000000000000000000000000000000000
dfs(1,1,0);
for(int j=1;j<=maxd;j++){
for(int i=1;i<=n;i++){
f[i][j]=f[f[i][j-1]][j-1];
}
}
while(m--){
int x,y,k;
scanf("%d%d%d",&x,&y,&k);
x^=lastans;
int t=lca(x,y);
int ans=ask(k,r[x],r[y],r[t],r[f[t][0]],1,num);
lastans=xx[ans];
printf("%d\n",lastans);
}
return 0;
}