//count the tree
/*
思路:每个点建立一颗线段树(增量建立),以遍历的时间为序,充分利用上一颗线段树的信息,在这题上一颗线段树就是父节点的线段树,因为我们每次更新的信息只有一个节点,一个节点被更新了,那么它的所有祖先节点也要相应的被更新,又因为在线段树中一个节点的祖先节点数不会超过(logN)个,所以这颗线段树和上一颗线段树大部分节点是一样的,只有刚刚说的logN个节点被改变了,所以我们只要保存这改变的logN个节点就可以了,因此整个程序的空间复杂度就是O(NlogN)。
在这题里面,因为要查找的两个点a,b的这两颗线段树的信息,都是从lca(a,b)这个点里面的信息改变来的,所以其中的信息重复了一次;
所以在后面判断第k个数是在左子树上还是在右子树上时要减掉lca(a,b)这个点之前的信息,接下来的处理就和求线性区间第k大
一样的思路了。
*/
实现代码:
#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
#define FN 2010000
#define N 101000
int tot,T[N],L[FN],R[FN],C[FN],pa[N][20],n;
int build(int l,int r)
{
int mid,rt=++tot;
if(l==r) return rt;
mid=(l+r)>>1;
L[rt]=build(l,mid);
R[rt]=build(mid+1,r);
return rt;
}
int update(int rt,int p,int d)
{
int newrt=++tot,l=1,r=n,mid,root=newrt;
C[newrt]=C[rt]+d;
while(r>l){
mid=(l+r)>>1;
if(p<=mid){
L[newrt]=++tot;R[newrt]=R[rt];
rt=L[rt];newrt=L[newrt];r=mid;
}
else{
R[newrt]=++tot;L[newrt]=L[rt];
rt=R[rt];newrt=R[newrt];l=mid+1;
}
C[newrt]=C[rt]+d;
}
return root;
}
struct node{ int id,v,sd; }A[N];
int query(int t1,int t2,int lca,int k)
{
int l=1,r=n,mid,p=T[pa[lca][0]],t;lca=T[lca];
while(r>l){
mid=(l+r)>>1;
t=C[L[t1]]+C[L[t2]]-C[L[lca]]-C[L[p]];
if(t>=k){
t1=L[t1];t2=L[t2];
lca=L[lca];p=L[p];r=mid;
}
else{
k-=t;t1=R[t1];t2=R[t2];
lca=R[lca];p=R[p];l=mid+1;
}
}return A[l].v;
}
struct E{ int t,next; }edge[2*N];
int ant,head[N],H[N];
void add(int a,int b)
{
edge[ant].t=b;
edge[ant].next=head[a];
head[a]=ant++;
}
void dfs(int rt,int p,int dep)
{
int i;
pa[rt][0]=p;H[rt]=dep;
T[rt]=update(T[p],A[rt].sd,1);
for(i=head[rt];i!=-1;i=edge[i].next)
{
if(edge[i].t==p) continue;
dfs(edge[i].t,rt,dep+1);
}
}
bool cmp1(node a,node b){ return a.v<b.v; }
bool cmp2(node a,node b){ return a.id<b.id;}
int B[N];
int Lca(int x,int y)
{
int k;
if(x==y) return x;
if(H[x]<H[y]) swap(x,y);
for(k=B[H[x]-H[y]];k>=0;--k)
if(H[x]-H[y]>=(1<<k))
x=pa[x][k];
if(x==y) return x;
for(k=B[H[x]];k>=0;--k)
{
if(pa[x][k]&&pa[x][k]!=pa[y][k])
x=pa[x][k],y=pa[y][k];
}
return pa[x][0];
}
int main()
{
int m,i,a,b,k,lca;
for(i=1;i<=N;i++)
{
B[i]=0;while(i>=(1<<B[i])) B[i]++;
}
while(scanf("%d%d",&n,&m)!=EOF)
{
tot=ant=0;
memset(head,-1,sizeof(head));
for(i=1;i<=n;i++){
A[i].id=i;
scanf("%d",&A[i].v);
}
T[0]=build(1,n);
for(i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
add(a,b);add(b,a);
}
sort(A+1,A+n+1,cmp1);
for(i=1;i<=n;i++) A[i].sd=i;
sort(A+1,A+n+1,cmp2);
dfs(1,0,0);
sort(A+1,A+n+1,cmp1);
for(k=1;k<20;k++)
for(i=1;i<=n;i++)
if(pa[i][k-1]) pa[i][k]=pa[pa[i][k-1]][k-1];
while(m--)
{
scanf("%d%d%d",&a,&b,&k);lca=Lca(a,b);
printf("%d\n",query(T[a],T[b],lca,k));
}
}
return 0;
}