问树上有多少点对之间路径边权
m
a
x
−
m
i
n
=
k
max-min=k
max−min=k,
k
k
k为定值。
k
≤
n
≤
2
∗
1
0
5
k\leq n\leq2*10^5
k≤n≤2∗105.
题解
其实这题比较套路,并不难想。
关于树上路径计数的问题,一般先考虑点分治能不能实现,发现是可以的。
按照一般点分治的套路,找到某个子树重心后,记录每个点到它的路径边权
m
a
x
,
m
i
n
max,min
max,min,有两种情况,一种是重心为路径的一端,直接枚举判断;另一种是重心在路径中间。
第二种情况,按
m
a
x
max
max从小到大排序,枚举一条路径和前面的另一条组合,
因为已经排好序了,所以
m
a
x
max
max一定在当前这条路径上,接着再分两种情况,一种是该路径的
m
a
x
−
m
i
n
<
k
max-min<k
max−min<k,那么查找前面
m
i
n
=
m
a
x
−
k
min=max-k
min=max−k的数量加入答案;一种是该路径的
m
a
x
−
m
i
n
=
k
max-min=k
max−min=k,则查找前面
m
i
n
≥
m
a
x
−
k
min\geq max-k
min≥max−k的数量加入答案。用树状数组维护。
#include<cstdio>#include<cstring>#include<algorithm>usingnamespace std;#define ll long long#define N 200010int n, K;
ll ans =0;int last[N], nxt[N *2], to[N *2], we[N *2], len =0;int vi[N], si[N], sum[N], s, rt, mi;int tot =0, f[N];struct node {int mx, mi, r;}a[N];voidadd(int x,int y,int w){
to[++len]= y;
we[len]= w;
nxt[len]= last[x];
last[x]= len;}voiddfs(int k,int fa){
si[k]=1;for(int i = last[k]; i; i = nxt[i])if(to[i]!= fa &&!vi[to[i]]){dfs(to[i], k);
si[k]+= si[to[i]];}}voidfind(int k,int fa){int mx = s - si[k];for(int i = last[k]; i; i = nxt[i])if(to[i]!= fa &&!vi[to[i]]){find(to[i], k);
mx =max(mx, si[to[i]]);}if(mx < mi) mi = mx, rt = k;}voiddfs1(int k,int fa,int t0,int t1,int r){if(t1) a[++tot].mx = t1, a[tot].mi = t0, a[tot].r = r;for(int i = last[k]; i; i = nxt[i])if(to[i]!= fa &&!vi[to[i]]){dfs1(to[i], k,min(t0, we[i]),max(t1, we[i]), r ==0? to[i]: r);}}intcmp(node x, node y){if(x.mx == y.mx)return x.mi < y.mi;return x.mx < y.mx;}intcmp1(node x, node y){return x.r < y.r;}intlow(int x){return x &(-x);}voidins(int k,int c){for(int i = k; i <= n; i +=low(i)) f[i]+= c;}intct(int k){int s =0;for(int i = k; i; i -=low(i)) s += f[i];return s;}voidds(int l,int r,int o){sort(a + l, a + r +1, cmp);for(int i = l; i <= r; i++){if(a[i].mx - a[i].mi == K){
ans +=(i - l -ct(a[i].mi -1))* o;}elseif(a[i].mx - a[i].mi < K) ans += sum[a[i].mx - K]* o;
sum[a[i].mi]++;ins(a[i].mi,1);}for(int i = l; i <= r; i++) sum[a[i].mi]--,ins(a[i].mi,-1);}voidcalc(int k){
tot =0;dfs1(k,0, n +1,0,0);sort(a +1, a + tot +1, cmp);for(int i =1; i <= tot; i++)if(a[i].mx - a[i].mi == K) ans++;ds(1, tot,1);sort(a +1, a + tot +1, cmp1);int la =1;for(int i =1; i <= tot; i++){if(i == tot || a[i].r != a[i +1].r){ds(la, i,-1);
la = i +1;}}}voidsolve(int k){dfs(k,0);
s = si[k], mi = n +1;find(k,0);calc(rt);
vi[rt]=1;for(int i = last[rt]; i; i = nxt[i])if(!vi[to[i]])solve(to[i]);}intmain(){int i, x, y, w;scanf("%d%d",&n,&K);for(i =1; i < n; i++){scanf("%d%d%d",&x,&y,&w);add(x, y, w),add(y, x, w);}solve(1);printf("%lld\n", ans);return0;}