题意描述:
给出一个N节点的树,每个点有点权
M个询问 每次询问 U->V路径上的第k值
题目分析:
区间第K值主席树。
树上主席树维护点权得出的当前答案应该是 sum[u]+sum[v]-sum[lca]-sum[fa[lca]]
题目链接:
AC代码:
#include <iostream>
#include <cstdio>
#include <algorithm>
const int maxm=2e5+1000;
int cnt,sz,tot,n,m;
int head[maxm],net[maxm*2],to[maxm*2],cost[maxm*2];
int root[maxm];
int fa[maxm],deep[maxm],top[maxm],siz[maxm],son[maxm];
int _hash[maxm];
struct node{
int u,v,k;
}ask[maxm];
inline void addedge(int u,int v,int c){to[++cnt]=v,cost[cnt]=c,net[cnt]=head[u],head[u]=cnt;}
struct mtree{
int lson[maxm*20],rson[maxm*20];
int sum[maxm*20];
void insert(int &now,int pre,int l,int r,int num)
{
now=++sz;
sum[now]=sum[pre]+1;
lson[now]=lson[pre],rson[now]=rson[pre];
//printf("%d %d %d\n",l,r,num);
if(l>=r) return;
int mid=(l+r)>>1;
num<=mid?insert(lson[now],lson[pre],l,mid,num):insert(rson[now],rson[pre],mid+1,r,num);
}
int ask_kth(int u,int v,int lca,int l,int r,int kth)
{
if(r<=kth) return sum[u]+sum[v]-sum[lca]-sum[lca];
if(l>kth) return 0;
int mid=(l+r)>>1;
return ask_kth(lson[u],lson[v],lson[lca],l,mid,kth)+ask_kth(rson[u],rson[v],rson[lca],mid+1,r,kth);
}
}st;
void dfs1(int now,int fax,int val)
{
st.insert(root[now],root[fax],1,tot+1,val);
fa[now]=fax,deep[now]=deep[fax]+1;
siz[now]=1,son[now]=0;
int maxson=-1;
for(int i=head[now];i;i=net[i])
if(to[i]!=fax)
{
dfs1(to[i],now,std::lower_bound(_hash+1,_hash+tot+1,cost[i])-_hash);
siz[now]+=siz[to[i]];
if(siz[to[i]]>maxson) maxson=siz[to[i]],son[now]=to[i];
}
}
void dfs2(int now,int topx)
{
top[now]=topx;
if(son[now]) dfs2(son[now],topx);
for(int i=head[now];i;i=net[i])
if(!top[to[i]]) dfs2(to[i],to[i]);
}
inline int LCA(int u,int v)
{
while(top[u]!=top[v])
{
if(deep[top[u]]<deep[top[v]]) std::swap(u,v);
u=fa[top[u]];
}
return deep[u]<deep[v]?u:v;
}
int main()
{
//freopen("1.txt","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1,u,v,c;i<n;i++)
{
scanf("%d%d%d",&u,&v,&c);
_hash[++tot]=c;
addedge(u,v,c);
addedge(v,u,c);
}
for(int i=1,u,v,k;i<=m;i++)
{
scanf("%d%d%d",&u,&v,&k);
ask[i]=(node){u,v,k};
_hash[++tot]=k;
}
std::sort(_hash+1,_hash+tot+1);
tot=std::unique(_hash+1,_hash+tot+1)-_hash-1;
dfs1(1,0,tot+1);
dfs2(1,1);
for(int i=1,u,v,k;i<=m;i++)
{
u=ask[i].u,v=ask[i].v,k=std::lower_bound(_hash+1,_hash+tot+1,ask[i].k)-_hash;
//printf("%d\n",k);
printf("%d\n",st.ask_kth(root[u],root[v],root[LCA(u,v)],1,tot+1,k));
}
return 0;
}