题目链接:点击打开链接
题目大意:给你一颗n个结点的树,每个结点都有一个权值。
为从结点u到结点v这条路径上所有的权值第k大是多少
思路:主要思路是 树上第k大,LCA+主席树
为什么要用到LCA?
因为u到v的路径必定是从u走到他们的LCA(u,v)点,然后从LCA(u,v)点走到v。
所以我们需要提前求出这棵树的LCA
主席树怎么建树?
每一棵线段树都是每个结点继承其父节点的线段树加上自己的结点。
如此对于一条路径就可以统计从其祖先结点到其子节点这条路径上增加了哪些数
在主席树路径上找第k大时,即统计LCA(u,v)到u和v两条路径上总共出现了【l,r】区间内的数字有多少个,大于k往左子树,小于k往右子树,和普通的主席树找第k大相同
注意一下,因为lca(u,v)是u,v的交点,统计时候如果被统计两次,需要减去一次。
代码参考kuangbin模板:
#include <stdio.h>
#include <string.h>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <string>
#include <math.h>
#include <stdlib.h>
#include <time.h>
using namespace std;
//主席树部分 *****************8
const int MAXN = 400010;
const int M = MAXN * 40;
int n,q,m,TOT;
int a[MAXN], t[MAXN];
int T[M], lson[M], rson[M], c[M];
void Init_hashs()
{
for(int i = 1; i <= n;i++)
t[i] = a[i];
sort(t+1,t+1+n);
m = unique(t+1,t+n+1)-t-1;
}
int build(int l,int r)
{
int root = TOT++;
c[root] = 0;
if(l != r)
{
int mid = (l+r)>>1;
lson[root] = build(l,mid);
rson[root] = build(mid+1,r);
}
return root;
}
int hashs(int x)
{
return lower_bound(t+1,t+1+m,x) - t;
}
int update(int root,int pos,int val)
{
int newroot = TOT++, tmp = newroot;
c[newroot] = c[root] + val;
int l = 1, r = m;
while( l < r)
{
int mid = (l+r)>>1;
if(pos <= mid)
{
lson[newroot] = TOT++; rson[newroot] = rson[root];
newroot = lson[newroot]; root = lson[root];
r = mid;
}
else
{
rson[newroot] = TOT++; lson[newroot] = lson[root];
newroot = rson[newroot]; root = rson[root];
l = mid+1;
}
c[newroot] = c[root] + val;
}
return tmp;
}
int query(int left_root,int right_root,int LCA,int k)
{
int lca_root = T[LCA];
int pos = hashs(a[LCA]);
int l = 1, r = m;
while(l < r)
{
int mid = (l+r)>>1;
int tmp = c[lson[left_root]] + c[lson[right_root]] - 2*c[lson[lca_root]] + (pos >= l && pos <= mid);
//统计这条路径上权值在离散化数组t中坐标在[l,mid]之间的个数
//后面减去两倍的他们的lca和后面那个判断 主要是为了如果他们lca被减去两边的话 需要补加一个
if(tmp >= k)
{
left_root = lson[left_root];
right_root = lson[right_root];
lca_root = lson[lca_root];
r = mid;
}
else
{
k -= tmp;
left_root = rson[left_root];
right_root = rson[right_root];
lca_root = rson[lca_root];
l = mid + 1;
}
}
return l;
}
/************LCA模板部分************/
int rmq[2*MAXN];//建立RMQ的数组
//ST算法,里面含有初始化init(n)和query(s,t)函数
//点的编号从1开始,1-n.返回最小值的下标
struct ST
{
int mm[2*MAXN];//mm[i]表示i的最高位,mm[1]=0,mm[2]=1,mm[3]=1,mm[4]=2
int dp[MAXN*2][20];
void init(int n)
{
mm[0]=-1;
for(int i=1;i<=n;i++)
{
mm[i]=((i&(i-1))==0?mm[i-1]+1:mm[i-1]);
dp[i][0]=i;
}
for(int j=1;j<=mm[n];j++)
for(int i=1;i+(1<<j)-1<=n;i++)
dp[i][j]=rmq[dp[i][j-1]]<rmq[dp[i+(1<<(j-1))][j-1]]?dp[i][j-1]:dp[i+(1<<(j-1))][j-1];
}
int query(int a,int b)//查询a到b间最小值的下标
{
if(a>b)swap(a,b);
int k=mm[b-a+1];
return rmq[dp[a][k]]<rmq[dp[b-(1<<k)+1][k]]?dp[a][k]:dp[b-(1<<k)+1][k];
}
};
//边的结构体定义
struct Node
{
int to,next;
};
/* ******************************************
LCA转化为RMQ的问题
MAXN为最大结点数。ST的数组 和 F,edge要设置为2*MAXN
F是欧拉序列,rmq是深度序列,P是某点在F中第一次出现的下标
*********************************************/
struct LCA2RMQ
{
int n;//结点个数
Node edge[2*MAXN];//树的边,因为是建无向边,所以是两倍
int tol;//边的计数
int head[MAXN];//头结点
bool vis[MAXN];//访问标记
int F[2*MAXN];//F是欧拉序列,就是DFS遍历的顺序
int P[MAXN];//某点在F中第一次出现的位置
int cnt;
ST st;
void init(int n)//n为所以点的总个数,可以从0开始,也可以从1开始
{
this->n=n;
tol=0;
memset(head,-1,sizeof(head));
}
void addedge(int a,int b)//加边
{
edge[tol].to=b;
edge[tol].next=head[a];
head[a]=tol++;
}
int query(int a,int b)//传入两个节点,返回他们的LCA编号
{
return F[st.query(P[a],P[b])];
}
void dfs(int a,int lev)
{
vis[a]=true;
++cnt;//先加,保证F序列和rmq序列从1开始
F[cnt]=a;//欧拉序列,编号从1开始,共2*n-1个元素
rmq[cnt]=lev;//rmq数组是深度序列
P[a]=cnt;
for(int i=head[a];i!=-1;i=edge[i].next)
{
int v=edge[i].to;
if(vis[v])continue;
dfs(v,lev+1);
++cnt;
F[cnt]=a;
rmq[cnt]=lev;
}
}
void solve(int root)
{
memset(vis,false,sizeof(vis));
cnt=0;
dfs(root,0);
st.init(2*n-1);
}
}lca;
/***************主席树的建树*****************/
void dfs_build(int u,int pre)
{
int pos = hashs(a[u]);
T[u] = update(T[pre],pos,1);
for(int i = lca.head[u]; i != -1;i = lca.edge[i].next)
{
int v =lca.edge[i].to;
if(v == pre)continue;
dfs_build(v,u);
}
}
int main()
{
while(scanf("%d%d",&n,&q) == 2)
{
for(int i = 1;i <= n;i++)
scanf("%d",&a[i]);
Init_hashs();
lca.init(n);
TOT = 0;
int u,v;
for(int i = 1;i < n;i++)
{
scanf("%d%d",&u,&v);
lca.addedge(u,v);
lca.addedge(v,u);
}
lca.solve(1);
T[n+1] = build(1,m);
dfs_build(1,n+1);
int k;
while(q--)
{
scanf("%d%d%d",&u,&v,&k);
printf("%d\n",t[query(T[u],T[v],lca.query(u,v),k)]);
}
return 0;
}
return 0;
}