Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
题意:求这棵树中,两点之间距离小于k的点对数。
solution:点分治。
介绍一下点分治:
点分治就是选取一个点将无根树转化为有根树,再递归处理以根节点的儿子节点为根的子树。
点分治的步骤:
1.将无根树转化为有根树。
可以证明,以重心为根,每棵树子结点个数均不大于n/2。因此,递归深度不超过(logn)
找重心的方法——树型DP
void getroot(int u, int fa){// u表示当前节点 fa表示u的父亲节点
son[u]=1; f[u]=0;
// son记录u的子节点数 f记录以u为根最大子结点大小
for ( int i=head[u]; i; i=e[i].nxt ){
int v=e[i].to;
if( v==fa || vis[v] ) continue;
getroot(v, u);
son[u]+=son[v];
f[u]=max(f[u], son[v]);
}
f[u]=max(f[u], sum-son[u]);
//如果以u为根,u的父亲节点会变成u的子节点
//sum-son[u]表示以u的父亲节点为子结点的节点数
if( f[u]<f[root] ) root=u;
}
2. 统计过根节点的路径(过根节点的所有路径-相交的路径)
在以1号节点为根节点时,计算出了6,7号节点的距离=1号路径(绿色)+2号路径(绿色)——但事实上6,7号节点的距离=3号路径(红色)的长度。因此我们要减去1,2号路径
3.删除根节点。
4.递归子树。
完整代码
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 1e4 + 5;
#define Inf 0x7fffffff
inline int read(){
int x=0, f=1; char ch=getchar();
while( !isdigit(ch) ) {
if( ch=='-' ) f=-1;
ch=getchar();
}
while( isdigit(ch) ){
x=x*10+ch-'0';
ch=getchar();
}
return x*f;
}
struct Edge{
int nxt, to, w;
}e[N<<1];
int head[N], son[N], f[N], dis[N], d[N], vis[N];
int tot, root, n, sum, k, ans;
void addeage(int u, int v, int w){
e[++tot].nxt=head[u], e[tot].to=v, e[tot].w=w;
head[u]=tot;
}
void getroot(int u, int fa){
son[u]=1; f[u]=0;
for ( int i=head[u]; i; i=e[i].nxt ){
int v=e[i].to;
if( v==fa || vis[v] ) continue;
getroot(v, u);
son[u]+=son[v];
f[u]=max(f[u], son[v]);
}
f[u]=max(f[u], sum-son[u]);
if( f[u]<f[root] ) root=u;
}
void getdis(int u, int fa){
d[++tot]=dis[u];
for ( int i=head[u]; i; i=e[i].nxt ){
int v=e[i].to;
if( v==fa || vis[v] ) continue;
dis[v]=dis[u]+e[i].w;
getdis(v, u);
}
}
int calc(int u, int ds){
dis[u]=ds; tot=0;
getdis(u, -1);
sort(d+1,d+1+tot);
int l=1, r=tot, sum=0;
while( l<r ){
if( d[l]+d[r]<=k ) sum+=r-l, l++;
else r--;
}
return sum;
}
void solve(int u){
ans+=calc(u, 0);
vis[u]=1;
for ( int i=head[u]; i; i=e[i].nxt ){
int v=e[i].to;
if( vis[v] ) continue;
ans-=calc(v,e[i].w);
sum=son[v];
root=0;
getroot(v, -1);
solve(root);
}
}
int main(){
while( ~scanf("%d%d", &n, &k ) && n ){
ans=0, root=0, tot=0;
memset(vis,0,sizeof(vis));
memset(head,0,sizeof(head));
for ( int i=1; i<n; i++ ){
int x, y, z;
x=read(), y=read(), z=read();
addeage(x,y,z), addeage(y,x,z);
}
f[0]=Inf, sum=n;
getroot(1,-1);
solve(root);
printf("%d\n", ans );
}
}