Description
给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。
Input
第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。
Output
M行,表示每个询问的答案。
Sample Input
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2
Sample Output
2
8
9
105
7
HINT
HINT:
N,M<=100000
暴力自重。。。
题解
建立一棵可持久化权值线段树,对于树上的每一个点,以它的父亲节点为上一个版本,根节点的父亲节点是0
询问u与v节点之间第K小的点权,就是询问 (root,u) + (root,v) - (root,lca(u,v)) - (root,fa(lca(u,v))) 第K小的点权,在权值线段树上二分即可。
代码
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#define ll long long
#define N 100011
#define M 2000005
struct seg{int v;int p;}a[N],hash[N];
int n,m,tot,cnt,last,color,rt[N];
int ret[M],Next[M],Head[N];
int ls[M],rs[M],sum[M];
int dep[N],fa[N][18];
using namespace std;
bool cmp(seg a,seg b)
{
return a.v<b.v;
}
int read()
{
int x=0,f=1;char ch=getchar();
while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void ins(int u,int v)
{
tot++;
ret[tot]=v;
Next[tot]=Head[u];
Head[u]=tot;
}
void change(int k,int &p,int l,int r,int x)
{
if (!p) p=++cnt;
sum[p]=sum[k]+1;
if (l==r) return;
int mid=(l+r)/2;
if (x<=mid) {rs[p]=rs[k];change(ls[k],ls[p],l,mid,x);}
else {ls[p]=ls[k];change(rs[k],rs[p],mid+1,r,x);}
}
void dfs(int u,int f)
{
fa[u][0]=f;
dep[u]=dep[f]+1;
change(rt[f],rt[u],1,color,hash[u].v);
for (int i=Head[u];i;i=Next[i])
{
int v=ret[i];
if (v!=f) dfs(v,u);
}
}
void build()
{
for (int k=1;k<=17;k++)
for (int i=1;i<=n;i++)
fa[i][k]=fa[fa[i][k-1]][k-1];
}
int lca(int a,int b)
{
if (dep[a]<dep[b]) swap(a,b);
for (int i=16;i>=0;i--)
if (dep[a]-(1<<i)>=dep[b]) a=fa[a][i];
if (a==b) return a;
for (int i=16;i>=0;i--) //i>=0 写成 i>0 RE调了半天。。
if (fa[a][i]!=fa[b][i])
{
a=fa[a][i];
b=fa[b][i];
}
return fa[a][0];
}
int query(int x,int y,int rk)
{
int a=lca(x,y),b=fa[a][0];
a=rt[a];b=rt[b];x=rt[x];y=rt[y];
int l=1,r=color;
while (l<r)
{
int mid=(l+r)/2;
int num=sum[ls[x]]+sum[ls[y]]-sum[ls[a]]-sum[ls[b]];
if (num>=rk){
r=mid;
x=ls[x];
y=ls[y];
a=ls[a];
b=ls[b];
}
else
{
l=mid+1;
rk-=num;
x=rs[x];
y=rs[y];
a=rs[a];
b=rs[b];
}
}
return hash[l].p;
}
int main()
{
n=read();m=read();
for (int i=1;i<=n;i++)
{
a[i].v=read();
a[i].p=i;
}
sort(a+1,a+n+1,cmp);
color=1;
int pre=a[1].v;
for (int i=1;i<=n;i++)
{
if (a[i].v!=pre&&i!=1){color++;pre=a[i].v;}
hash[a[i].p].v=color;
hash[color].p=a[i].v;
}
for (int i=1;i<n;i++)
{
int u=read(),v=read();
ins(u,v);ins(v,u);
}
dep[1]=1;
dfs(1,0);
build();
ll last=0;
while (m--)
{
int x=read(),y=read(),rk=read();
x^=last;
last=query(x,y,rk);
printf("%d",last);
if (m) printf("\n");
}
return 0;
}