题目描述
给定一棵 n n n 个节点的树,每条边有边权,求出树上两点距离小于等于 k k k 的点对数量。
输入格式
第一行输入一个整数 n n n,表示节点个数。
第二行到第 n n n 行每行输入三个整数 u , v , w u,v,w u,v,w,表示 u u u 与 v v v 有一条边,边权是 w w w。
第 n + 1 n+1 n+1 行一个整数 k k k 。
输出格式
一行一个整数,表示答案。
输入输出样例
输入
7
1 6 13
6 3 9
3 5 7
4 1 3
2 4 20
4 7 2
10
输出
5
说明/提示
对于全部的测试点,保证:
- 1 ≤ n ≤ 4 × 1 0 4 1\leq n\leq 4\times 10^4 1≤n≤4×104 。
- 1 ≤ u , v ≤ n 1\leq u,v\leq n 1≤u,v≤n
- 0 ≤ w ≤ 1 0 3 0\leq w\leq 10^3 0≤w≤103
- 0 ≤ k ≤ 2 × 1 0 4 0\leq k\leq 2\times 10^4 0≤k≤2×104
点分治步骤
- 找到树的重心
- 将重心视为根节点,那么树上任意两点有两种情况
- 路径经过根节点
- 路径不经过根节点
- 通过 c a l c calc calc 函数计算出第一种情况下的答案,把根节点从树中删去
- 对每棵子树执行上面的操作
calc函数的计算方法
- 计算出每个结点到根节点的距离 d [ i ] d[i] d[i]
- 将树上的结点按照 d [ i ] d[i] d[i] 递增排序
- 指针 l l l 指向 d [ 1 ] d[1] d[1] ,指针 r r r 指向 d [ n ] d[n] d[n] 。
- 若 l l l 与 r r r 指向结点的距离小于 k k k ,则 a n s + = r − l + 1 , l + + ans+=r-l+1,l++ ans+=r−l+1,l++。
- 否则 r − − r-- r−−。当 l > = r l>=r l>=r 的时候退出循环。
按照上面的方法,会把不经过根节点的路径也算入进去。利用容斥原理修正答案:
ans-=calc(y,edge[i]);
c a l c calc calc 函数的第一个参数为树的根节点,第二个参数为附加距离。
#include<bits/stdc++.h>
using namespace std;
const int N=4e4+10;
int n,k,u,v,w,root,ans,vis[N],mxSize,Size[N],len[N],d[N];
int head[N],ver[N<<1],Next[N<<1],edge[N<<1],tot;
void add(int u,int v,int w){
ver[++tot]=v,edge[tot]=w,Next[tot]=head[u],head[u]=tot;
}
void getRoot(int x,int fa){
Size[x]=1; int mx=0;
for(int i=head[x];i;i=Next[i]){
int y=ver[i];
if(vis[y]||y==fa) continue;
getRoot(y,x);
Size[x]+=Size[y];
mx=max(mx,Size[y]);
}
mx=max(mx,n-Size[x]);
if(mx<mxSize) root=x,mxSize=mx;
}
void getDis(int x,int fa){
len[++len[0]]=d[x];
for(int i=head[x];i;i=Next[i]){
int y=ver[i];
if(vis[y]||y==fa) continue;
d[y]=d[x]+edge[i];
getDis(y,x);
}
}
int calc(int x,int w){
d[x]=w; len[0]=0; int ret=0;
getDis(x,x);
sort(len+1,len+len[0]+1);
for(int l=1,r=len[0];l<r;){
if(len[l]+len[r]<=k) ret+=r-l,l++;
else r--;
}
return ret;
}
void solve(int x){
vis[x]=1;
ans+=calc(x,0);
for(int i=head[x];i;i=Next[i]){
int y=ver[i];
if(vis[y]) continue;
ans-=calc(y,edge[i]);
mxSize=n,getRoot(y,y);
solve(root);
}
}
int main(){
scanf("%d",&n);
ans=tot=0;
for(int i=1;i<n;i++){
scanf("%d%d%d",&u,&v,&w);
add(u,v,w),add(v,u,w);
}
scanf("%d",&k);
mxSize=n,getRoot(1,1);
solve(root);
printf("%d\n",ans);
}