【题目】
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
【分析】
题目大意:(多组数据)给出一棵边带权树,求出这棵树中距离不超过 k k k 的点对的数量
题解:点分治模板题
由于这是我的第一道点分治题,我还是好好写一下博客吧
先假设这是一道有根树,那满足条件的点对必然是以下两种情况:
- 它们的路径经过根节点
- 它们的路径不经过根节点(也就是说它们在同一个子树中)
对于 2,可以把它当成子问题,递归求解,现在就是讨论如何求出 1
假设 d i s i dis_i disi 为 i i i 到根的路径长度,用 d f s dfs dfs 求出所有点到根的距离,然后对所有 d i s dis dis 排序,这样就便于统计 d i s x + d i s y ≤ k dis_x+dis_y≤k disx+disy≤k 的总数,但这样做我们用把 2 的部分情况考虑进去,还要减掉这些情况
怎么选这个根呢,考虑用重心,因为减去重心后,子树的 s i z e size size 都会减少一半,这样可以保证复杂度
递归层数 O( l o g    n log\;n logn), s o r t sort sort 是 O( n ∗ l o g    n n * log\;n n∗logn),总复杂度是O( n ∗ l o g 2    n n*log^2\;n n∗log2n)
【代码】
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 50005
#define inf (1ll<<31ll)-1
using namespace std;
int n,k,t,ans,num,root,sum;
int d[N],size[N],Max[N];
int first[N],v[N],w[N],next[N];
bool vis[N];
void add(int x,int y,int z)
{
t++;
next[t]=first[x];
first[x]=t;
v[t]=y;
w[t]=z;
}
void dfs(int x,int father)
{
int i,j;
Max[x]=0;
size[x]=1;
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father&&!vis[j])
{
dfs(j,x);
size[x]+=size[j];
Max[x]=max(Max[x],size[j]);
}
}
}
void find(int rt,int x,int father)
{
int i,j;
Max[x]=max(Max[x],size[rt]-size[x]);
if(num>Max[x]) num=Max[x],root=x;
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father&&!vis[j])
find(rt,j,x);
}
}
void dist(int x,int father,int len)
{
int i,j;
d[++sum]=len;
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father&&!vis[j])
dist(j,x,len+w[i]);
}
}
int calc(int x,int l)
{
sum=0,dist(x,0,l);
sort(d+1,d+sum+1);
int ans=0,i=1,j=sum;
while(i<j)
{
while(d[i]+d[j]>k&&i<j) j--;
ans+=j-i;i++;
}
return ans;
}
void solve(int x)
{
int i,j;
dfs(x,0);
num=inf,find(x,x,0);
ans+=calc(root,0);
vis[root]=true;
for(i=first[root];i;i=next[i])
{
j=v[i];
if(!vis[j])
{
ans-=calc(j,w[i]);
solve(j);
}
}
}
int main()
{
int x,y,z,i;
while(~scanf("%d%d",&n,&k))
{
ans=0,t=0;
if(!n&&!k) break;
memset(first,0,sizeof(first));
memset(vis,false,sizeof(vis));
for(i=1;i<n;++i)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z),add(y,x,z);
}
solve(1);
printf("%d\n",ans);
}
return 0;
}