2019 南昌邀请赛 Distance on the tree
主席树+lca(最近公共祖先)
最终结果就是root[u] - root[lca(u,v)] + root[v] - root[lca(u,v)] = root[u] + root[v] - 2*root[lca(u,v)]
解释的话就是root[i] 代表结点i上方的节点的数据,我们根据dfs来建树的话就可满足这一点。注意一个父亲的两个儿子之间是互不影响的,因为新建的线段树只是继承了他的父亲,他的兄弟的数据并没有被囊括进来,因此对于u,v来说两者之间重复的部分就只有公共祖先上方的节点的数据,并且这些数据是不需要的,因为两者之间的路径并不包括它们,所以root[lca(u,v)]要减两次。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;
typedef long long ll;
const int N =2e5+10;
struct node {int l,r,sum;}T[N*40];
ll v[N];
struct Edges
{
int u,v;
ll w;
}edge[N];
struct Edge
{
int to;
int w;
Edge(){}
Edge(int _to,ll _w):to(_to),w(_w){}
};
int cnt,root[N],ca[N][22],deep[N],len,n,m;
vector <Edge> e[N];
int build(int l,int r,int pre)
{
int cur = ++cnt;
T[cur] = T[pre];
if(l==r)return cur;
int mid = (l+r)/2;
T[cur].l = build(l,mid,T[pre].l);
T[cur].r = build(mid+1,r,T[pre].r);
return cur;
}
int update(int l,int r,int pre,int pos)
{
int cur = ++cnt;
T[cur] = T[pre];
T[cur].sum++;
if(l == r )return cur;
int mid = (l+r)>>1;
if(mid>=pos) T[cur].l = update(l,mid,T[pre].l,pos);
else T[cur].r = update(mid+1,r,T[pre].r,pos);
return cur;
}
void lca(int cur,int fa)
{
deep[cur] = deep[fa] + 1;
ca[cur][0] = fa;
for(int i=1;i<=20;i++)
{
if(ca[cur][i-1])
ca[cur][i] = ca[ca[cur][i-1]][i-1];
else
break;
}
for(int i=0;i<e[cur].size();i++)
{
int to = e[cur][i].to,p = e[cur][i].w;
if(to == fa)continue;
root[to] = update(1,len,root[cur],e[cur][i].w);
lca(to,cur);
}
}
int get_lca(int u,int v)
{
if(deep[v]>deep[u])
swap(u,v);
int tmp = deep[u] - deep[v];
for(int i=0;i<=20;i++)
if(tmp&(1<<i))
u = ca[u][i];
if(u == v)return u;
for(int i=20;i>=0;i--)
{
if(ca[u][i]!=ca[v][i])
{
u = ca[u][i];
v = ca[v][i];
}
}
return ca[u][0];
}
int query(int l,int r,int tl,int tr,int m,int pos)
{
if(pos >= r)
{
return T[tl].sum + T[tr].sum - T[m].sum*2;
}
int ans = 0;
int mid = (r+l)>>1;
ans += query(l,mid,T[tl].l,T[tr].l,T[m].l,pos);
if(pos > mid) ans += query(mid+1,r,T[tl].r,T[tr].r,T[m].r,pos);
return ans;
}
void init()
{
for(int i=1;i<=n-1;i++)
{
e[i].clear();
for(int j=0;j<=20;j++)
{
ca[i][j]=0;
}
}
}
int main()
{
scanf("%d%d",&n,&m);
for(int i=1;i<=n-1;i++)
{
scanf("%d%d%lld",&edge[i].u,&edge[i].v,&edge[i].w);
v[i]=edge[i].w;
}
sort(v+1,v+n);
len = unique(v+1,v+n) - (v+1);
init();
for(int i=1;i<=n-1;i++)
{
int u = edge[i].u,vv = edge[i].v;
int p = lower_bound(v+1,v+1+len,edge[i].w) - (v);
e[u].push_back(Edge(vv,p));
e[vv].push_back(Edge(u,p));
}
cnt=0;
root[0] = build(1,len,0);
lca(1,0);
while(m--)
{
int u,vv;
ll k;
scanf("%d%d%lld",&u,&vv,&k);
int p = upper_bound(v+1,v+1+len,k) - (v);
p--;
if(p==0)
{
printf("0\n");
continue;
}
int fa = get_lca(u,vv);
printf("%d\n",query(1,len,root[u],root[vv],root[fa],p));
}
return 0;
}