题目链接:https://nanti.jisuanke.com/t/38229
题意:给一个n个节点的树,m个询问,每次询问u到v中不大于k的边有几条
解法:LCA+主席树
网上比较流行的解法是将原树树链剖分后在原树上建主席树,事实上,本题树链剖分的目的是为了求lca,既然是为了求lca,就不需要写较为繁琐的树链剖分,我直接写了tarjan的倍增求lca的方法,然后在树上建主席树。
下一步的问题就是如何实现这个询问了,首先将所有边权离散后,用upper_bound查找第一个小于k的位置,然后问题就转化为区间第k大问题了。主席树就是为了解决这个问题的,那么如何在树上建主席树呢?主席树的insert的本质是插入当前点的后继点,那么树上的如何定义一个点的后继点呢,很自然的想到了父子结点关系,子节点即是父节点的后继点,然后就可以insert子节点来建立主席树了。
#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
struct edge
{
int u,v,w;
}e[maxn<<1];
int n;
int Next[maxn<<1];
int head[maxn<<1];
int dp[maxn][20];
int dep[maxn];
int cnt;
int tot;
void add(int u,int v,int w)
{
e[cnt].u=u;
e[cnt].v=v;
e[cnt].w=w;
Next[cnt]=head[u];
head[u]=cnt++;
}
vector<int>vis;
struct Node
{
int l,r,val;
Node(){}
Node(int a,int b,int c)
{
l=a;
r=b;
val=c;
}
}node[maxn*50];
int root[maxn];
int arr[maxn];
int gettid(int num)
{
return lower_bound(vis.begin(),vis.end(),num)-vis.begin()+1;
}
int insert(int num,int l,int r,int val)
{
int oo=++tot;
node[oo].l=node[num].l;
node[oo].r=node[num].r;
node[oo].val=node[num].val+1;
if(l==r)return oo;
int m=(l+r)>>1;
if(m>=val)
{
node[oo].l=insert(node[oo].l,l,m,val);
}
else
{
node[oo].r=insert(node[oo].r,m+1,r,val);
}
return oo;
}
int query(int u,int v,int l,int r,int k)
{
if(r<=k) return node[v].val-node[u].val;
int mid=(l+r)/2;
int ls=node[node[v].l].val-node[node[u].l].val;
if(mid>=k) return query(node[u].l,node[v].l,l,mid,k);
else return ls+query(node[u].r,node[v].r,mid+1,r,k);
}
void dfs(int num,int fa)
{
dep[num]=dep[fa]+1;
dp[num][0]=fa;
for(int i=1;(1<<i)<=dep[num];i++)
{
dp[num][i]=dp[dp[num][i-1]][i-1];
}
for(int i=head[num];i!=-1;i=Next[i])
{
int v=e[i].v;
int w=e[i].w;
if(v==fa)continue;
root[v]=insert(root[num],1,vis.size(),gettid(w));
dfs(v,num);
}
}
int lca(int a,int b)
{
if(dep[a]<dep[b])swap(a,b);
for(int i=19;i>=0;i--)
{
if(dep[a]-(1<<i)>=dep[b])a=dp[a][i];
}
if(a==b)return a;
for(int i=19;i>=0;i--)
{
if(dp[a][i]!=dp[b][i])
{
a=dp[a][i];
b=dp[b][i];
}
}
return dp[a][0];
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
memset(head,-1,sizeof(head));
cnt=0;
tot=0;
int m;
cin>>n>>m;
for(int i=1;i<n;i++)
{
int a,b,c;
cin>>a>>b>>c;
add(a,b,c);
add(b,a,c);
vis.push_back(c);
}
sort(vis.begin(),vis.end());
vis.erase(unique(vis.begin(),vis.end()),vis.end());
dfs(1,0);
while(m--)
{
int x,y,k;
cin>>x>>y>>k;
int LCA=lca(x,y);
int h=upper_bound(vis.begin(),vis.end(),k)-vis.begin();
if(h!=0){
int ans=query(root[LCA],root[x],1,vis.size(),h)+query(root[LCA],root[y],1,vis.size(),h);
cout<<ans<<'\n';
}
else{
cout<<"0"<<'\n';
}
}
return 0;
}