COT - Count on a tree
You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.
We will ask you to perform the following operation:
- u v k : ask for the kth minimum weight on the path from node u to node v
Input
In the first line there are two integers N and M. (N, M <= 100000)
In the second line there are N integers. The ith integer denotes the weight of the ith node.
In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v).
In the next M lines, each line contains three integers u v k, which means an operation asking for the kth minimum weight on the path from node u to node v.
思路:
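lca求最近公共祖先,主席树求第k小。(Approach: find the lowest common ancestor with binary lifting, and answer the k-th smallest query with a persistent segment tree (主席树). Version T[v] counts the weights on the root-to-v path, so the weights on the u–v path are counted by T[u] + T[v] − T[lca] − T[parent(lca)].)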
代码:
#include<bits/stdc++.h>
using namespace std;
// Capacity for node-indexed arrays: N <= 100000, plus slack for 1-based indexing.
const int maxn=100010;

int a[maxn];   // a[v]: weight of node v (1-based)
int b[maxn];   // sorted distinct weights; b[1..len] are valid
int T[maxn];   // T[v]: root of the segment-tree version for node v; T[0] is the empty tree

int len;       // number of distinct weights (size of b)
int tol;       // index of the last node allocated from the pool below

// Persistent segment tree node pool (index 0 is never allocated).
// maxn*20 covers the initial build (about 2*len nodes) plus one copied
// root-to-leaf path (at most 18 nodes when len <= 1e5) per tree node.
int L[maxn*20];    // left child
int R[maxn*20];    // right child
int sum[maxn*20];  // number of weights stored in the subtree

int dep[maxn];     // dep[v]: depth of v (root = 1, dep[0] = 0)
int f[maxn][22];   // f[v][i]: 2^i-th ancestor of v, 0 above the root
int lg[maxn];      // lg[i] = floor(log2(i)), filled in by main

vector<int> G[maxn];  // undirected adjacency lists
// Allocates a fresh segment tree over rank range [l, r] from the node pool and
// returns the index of its root. No counts are written, so every node keeps the
// zero-initialized sum of the global pool; main uses this as version 0 (T[0]).
int build(int l,int r)
{
    const int node=++tol;
    if(l>=r)
        return node;  // leaf: no children to link
    const int mid=(l+r)/2;
    L[node]=build(l,mid);
    R[node]=build(mid+1,r);
    return node;
}
// Persistent insertion of one occurrence of weight k.
// Returns the root of a new version equal to the tree rooted at `pre`, with every
// count on the root-to-leaf path for k increased by one. Only the nodes on that
// path are newly allocated; all other subtrees are shared with `pre`.
// The branch is chosen by comparing the raw weight k with b[mid], so b[l..r] is
// expected to hold the sorted distinct weights covered by this range.
int update(int pre,int l,int r,int k)
{
    const int node=++tol;
    sum[node]=sum[pre]+1;
    L[node]=L[pre];
    R[node]=R[pre];
    if(l>=r)
        return node;  // reached the leaf for k
    const int mid=(l+r)/2;
    if(k>b[mid])
        R[node]=update(R[pre],mid+1,r,k);
    else
        L[node]=update(L[pre],l,mid,k);
    return node;
}
// Traverses the part of the tree entered at v (whose parent is fa). For every
// visited node u with parent p it sets:
//   T[u]    = T[p] with a[u] inserted, i.e. the segment-tree version holding the
//             weights on the root-to-u path;
//   dep[u]  = dep[p] + 1;
//   f[u][i] = the 2^i-th ancestor of u, for every i with 2^i <= dep[u].
// An explicit stack replaces recursion. A path-shaped tree with N = 1e5 nodes
// would otherwise need 1e5 nested calls and can overflow the call stack.
// Each node is fully processed before its children are pushed, so a child always
// sees its parent's (and every ancestor's) T, dep and f already filled in.
void dfs(int v,int fa)
{
    vector<pair<int,int> > stk;  // pending (node, parent) pairs
    stk.push_back(make_pair(v,fa));
    while(!stk.empty())
    {
        const int u=stk.back().first;
        const int p=stk.back().second;
        stk.pop_back();

        T[u]=update(T[p],1,len,a[u]);
        dep[u]=dep[p]+1;
        f[u][0]=p;
        for(int i=1;(1<<i)<=dep[u];i++)
            f[u][i]=f[f[u][i-1]][i-1];

        for(size_t i=0;i<G[u].size();i++)
            if(G[u][i]!=p)
                stk.push_back(make_pair(G[u][i],u));
    }
}
// Lowest common ancestor of x and y using the binary-lifting table f.
// Phase 1: lift the deeper node until both share a depth. Each jump covers the
// largest power of two not exceeding the remaining depth gap.
// Phase 2: lift both nodes together, from the largest useful power down, as long
// as their ancestors differ; the parent of the final node is the answer.
int LCA(int x,int y)
{
    int u=x, v=y;  // u ends up the deeper (or equally deep) node
    if(dep[u]<dep[v])
    {
        const int tmp=u;
        u=v;
        v=tmp;
    }

    int gap=dep[u]-dep[v];
    while(gap>0)
    {
        u=f[u][lg[gap]];
        gap=dep[u]-dep[v];
    }
    if(u==v)
        return u;

    int step=lg[dep[u]];
    while(step>=0)
    {
        const int upU=f[u][step];
        const int upV=f[v][step];
        if(upU!=upV)
        {
            u=upU;
            v=upV;
        }
        --step;
    }
    return f[u][0];
}
// Walks four segment-tree versions in lockstep over rank range [l, r]. The
// multiset considered is counts(x) + counts(y) - counts(x1) - counts(y1).
// Returns the rank (index into b) of its k-th smallest element.
// Iterative form of a tail-recursive descent: go left while the left half holds
// at least k elements, otherwise skip those elements and go right.
int query(int x,int y,int x1,int y1,int l,int r,int k)
{
    while(l!=r)
    {
        const int mid=(l+r)/2;
        const int leftCount=sum[L[x]]+sum[L[y]]-sum[L[x1]]-sum[L[y1]];
        if(k<=leftCount)
        {
            x=L[x]; y=L[y]; x1=L[x1]; y1=L[y1];
            r=mid;
        }
        else
        {
            k-=leftCount;
            x=R[x]; y=R[y]; x1=R[x1]; y1=R[y1];
            l=mid+1;
        }
    }
    return l;
}
// Reads N, M, the node weights, the N-1 edges and M queries "u v k" from stdin.
// For each query it prints the k-th smallest weight on the path u..v.
// The count of each weight on the path is obtained from four prefix versions:
//   cnt(u..v) = cnt(T[u]) + cnt(T[v]) - cnt(T[lca]) - cnt(T[parent(lca)]).
// Every scanf result is checked: if the input ends early or is malformed, the
// program stops instead of running on uninitialized values.
int main()
{
    // lg[i] = floor(log2(i)); it increases by one exactly when i reaches the next
    // power of two.
    for(int i=2;i<maxn;i++)
        lg[i]=lg[i-1]+((1<<(lg[i-1]+1))==i);

    int n,m;
    if(scanf("%d%d",&n,&m)!=2) return 0;
    for(int i=1;i<=n;i++)
    {
        if(scanf("%d",&a[i])!=1) return 0;
        b[i]=a[i];
    }
    // Coordinate compression: b[1..len] becomes the distinct weights, ascending.
    sort(b+1,b+n+1);
    len=unique(b+1,b+n+1)-b-1;

    for(int i=1;i<n;i++)
    {
        int x,y;
        if(scanf("%d%d",&x,&y)!=2) return 0;
        G[x].push_back(y);
        G[y].push_back(x);
    }

    T[0]=build(1,len);  // empty version, used as the "parent" of the root (node 0)
    dfs(1,0);

    while(m--)
    {
        int x,y,z;
        if(scanf("%d%d%d",&x,&y,&z)!=3) break;
        const int lca=LCA(x,y);
        printf("%d\n",b[query(T[x],T[y],T[lca],T[f[lca][0]],1,len,z)]);
    }
    return 0;
}