题目链接: https://www.spoj.com/problems/COT/
COT - Count on a tree
You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
- u v k : ask for the kth minimum weight on the path from node u to node v
Input
In the first line there are two integers N and M (N, M <= 100000).
In the second line there are N integers. The ith integer denotes the weight of the ith node.
In the next N-1 lines, each line contains two integers u v, which describe an edge (u, v).
In the next M lines, each line contains three integers u v k, which means an operation asking for the kth minimum weight on the path from node u to node v.
Output
For each operation, print its result.
Example
Input:
8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
2 5 2
2 5 3
2 5 4
7 8 2
Output:
2
8
9
105
7
题解:
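此题按照 DFS 先序遍历为每个结点建立一棵主席树（可持久化权值线段树）：结点 u 的版本由其父结点的版本插入 a[u] 得到，因此 root[u] 统计了从根到 u 的路径上各权值出现的次数。查询结点 x 与 y 之间的路径时，设它们的最近公共祖先为 f，则路径上落在某一值域区间内的数的个数为 cnt(x)+cnt(y)-cnt(f)-cnt(fa[f])。代码中等价地写成 cnt(x)+cnt(y)-2·cnt(f)，再根据 a[f] 是否落在该区间内决定是否加 1（即把 f 结点本身补回来）。据此在值域线段树上逐层二分即可得到第 k 小，详见代码。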
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<vector>
#include<queue>
#include<stack>
using namespace std;
// Shorthand type names used by the solution.
typedef std::vector<int> VI;
typedef long long ll;
typedef std::pair<int,int> PII;

// Constants: a large sentinel, the usual prime modulus, and the capacity
// used for every per-vertex array (N <= 100000 in this problem).
const int inf = 0x3fffffff;
const ll mod = 1000000007;
const int maxn = 300000 + 100;

// Loop helpers: rep walks [a, n) upward, per walks [a, n) downward.
#define rep(i,a,n) for (int i=a;i<n;++i)
#define per(i,a,n) for (int i=n-1;i>=a;--i)
#define pb push_back
#define fi first
#define se second
// Graph storage as a "forward star" adjacency list: head[u] is the index of
// the most recently inserted edge leaving u (0 means u has no edges), and the
// edges themselves live in e[1..tol].
int n,m;                  // vertex count and query count (read in main)
int head[maxn];
struct edge
{
    int from;             // not set by add(); unused
    int to;               // endpoint of the edge
    int next;             // next edge with the same source, 0 terminates
};
edge e[maxn*2];           // each undirected edge is stored in both directions
int tol=0;                // number of edges inserted so far
// Insert the directed edge u -> v at the front of u's adjacency list.
void add(int u,int v)
{
    ++tol;
    e[tol].to=v;
    e[tol].next=head[u];
    head[u]=tol;
}
// Node of the persistent (versioned) segment tree over value ranks.
struct node
{
    int l;                // index of the left child in T (0 = empty subtree)
    int r;                // index of the right child in T
    int sum;              // how many inserted values fall in this node's range
};
node T[maxn*40];          // node pool; T[0] is the shared all-zero empty node
int cnt=0;                // highest index of T handed out so far
int root[maxn];           // root[u]: tree version built for vertex u (root[0] = empty)
int a[maxn];              // vertex weights, 1-indexed
int h[maxn];              // sorted distinct weights (coordinate compression)
int nn=0;                 // number of distinct weights stored in h
int deep[maxn];           // depth of each vertex; the DFS start vertex has depth 0
int fa[maxn][19];         // fa[u][i]: 2^i-th ancestor of u (stays 0 if it does not exist)
int cur=0;                // not referenced anywhere
void init_hash()
{
rep(i,1,n+1) h[i-1]=a[i];
sort(h,h+n);
nn=unique(h,h+n)-h;
}
// Map weight x to its 1-based rank among the distinct weights h[0..nn):
// the position of the first entry not less than x, plus one.
int getid(int x)
{
    const int *slot=lower_bound(h,h+nn,x);
    return int(slot-h)+1;
}
// Persistent point update. Starting from the version rooted at y, which spans
// value range [l, r], build a new version in which leaf `pos` has `val` added.
// The new root index is written to x. One fresh node is copied per level along
// the root-to-leaf path; all other subtrees are shared with version y.
void update(int l,int r,int &x,int y,int pos,int val)
{
    int *link=&x;   // slot that receives the index of the next fresh node
    int old=y;      // matching node in the previous version
    while(true)
    {
        const int fresh=++cnt;
        T[fresh]=T[old];
        T[fresh].sum+=val;
        *link=fresh;
        if(l==r) break;
        const int mid=(l+r)/2;
        if(pos<=mid)
        {
            link=&T[fresh].l;
            old=T[old].l;
            r=mid;
        }
        else
        {
            link=&T[fresh].r;
            old=T[old].r;
            l=mid+1;
        }
    }
}
// Find the rank of the k-th smallest weight on the path between two vertices.
// x and y are the tree versions root[u] and root[v] of the endpoints; `lca` is
// the vertex id of their lowest common ancestor. In any value range, the number
// of path weights is cnt(x) + cnt(y) - 2*cnt(lca), plus 1 when a[lca] lies in
// that range, because the LCA itself was subtracted twice.
// Returns a 1-based rank; the actual weight is h[rank - 1].
int query(int x,int y,int lca,int k)
{
    const int lcaRank=getid(a[lca]);
    int w=root[lca];
    int lo=1,hi=n;
    for(;lo<hi;)
    {
        const int mid=(lo+hi)/2;
        const bool lcaOnLeft=(lcaRank>=lo&&lcaRank<=mid);
        const int leftCount=T[T[x].l].sum+T[T[y].l].sum
                           -2*T[T[w].l].sum+(lcaOnLeft?1:0);
        if(k<=leftCount)
        {
            // The answer lies in the left half [lo, mid].
            x=T[x].l;
            y=T[y].l;
            w=T[w].l;
            hi=mid;
        }
        else
        {
            // Skip the left half and search [mid+1, hi].
            k-=leftCount;
            x=T[x].r;
            y=T[y].r;
            w=T[w].r;
            lo=mid+1;
        }
    }
    return lo;
}
// Build the per-vertex persistent segment trees and the binary-lifting table
// for the tree reachable from u, whose parent is f (0 for the tree root).
// root[x] is derived from root[parent] by inserting a[x], so root[x] counts the
// weights on the path from the start vertex to x.
//
// Uses an explicit stack instead of recursion. A path-shaped tree with
// N = 100000 would otherwise recurse 100000 frames deep, which can overflow
// small default thread stacks (typically 1 MB on Windows toolchains).
// As before, the caller must have set deep[u] and fa[u][0] (both are 0 for the
// tree root, because the globals are zero-initialized).
void dfs(int u,int f)
{
    std::stack<std::pair<int,int> > pending;   // (vertex, parent) still to process
    pending.push(std::pair<int,int>(u,f));
    while(!pending.empty())
    {
        const int x=pending.top().first;
        const int p=pending.top().second;
        pending.pop();
        // deep[x] and fa[x][0] were filled in when x was discovered. Every
        // ancestor of x was popped (and finished) before x was pushed, so the
        // ancestors' jump tables are complete.
        for(int i=1;i<19;i++)
        {
            if(deep[x]<(1<<i)) break;
            fa[x][i]=fa[fa[x][i-1]][i-1];
        }
        update(1,n,root[x],root[p],getid(a[x]),1);
        for(int i=head[x];i;i=e[i].next)
        {
            const int v=e[i].to;
            if(v==p) continue;   // do not walk back to the parent
            deep[v]=deep[x]+1;
            fa[v][0]=x;
            pending.push(std::pair<int,int>(v,x));
        }
    }
}
// Lowest common ancestor of x and y, found by binary lifting over fa[][] and deep[].
int lca(int x,int y)
{
    if(deep[x]<deep[y]) swap(x,y);   // from here on, x is the deeper vertex
    const int gap=deep[x]-deep[y];
    // Lift x to the depth of y, one set bit of the depth gap at a time.
    for(int bit=0;bit<19;++bit)
    {
        if((gap>>bit)&1) x=fa[x][bit];
    }
    if(x!=y)
    {
        // Raise both vertices while their ancestors differ; they end just
        // below the LCA.
        for(int bit=18;bit>=0;--bit)
        {
            if(fa[x][bit]==fa[y][bit]) continue;
            x=fa[x][bit];
            y=fa[y][bit];
        }
        x=fa[x][0];
    }
    return x;
}
// Answer one query: the k-th smallest weight on the path between vertices x and y.
int solve(int x,int y,int k)
{
    const int idx=query(root[x],root[y],lca(x,y),k);
    return h[idx-1];   // query() returns a 1-based rank; h is 0-based
}
int main()
{
scanf("%d%d",&n,&m);
rep(i,1,n+1) scanf("%d",&a[i]);
init_hash();
rep(i,1,n)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v),add(v,u);
}
dfs(1,0);
while(m--)
{
int u,v,k;
scanf("%d%d%d",&u,&v,&k);
printf("%d\n",solve(u,v,k));
}
return 0;
}