Tree
Time Limit: 1000 ms | Memory Limit: 30000 KB
Total Submissions: 27073 | Accepted: 9003
Description
Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u,v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many valid pairs the given tree has.
Input
The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.
The last test case is followed by a line containing two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
Source
给定一棵树,每条边有长度。
求距离不超过k(即 dist(u,v) <= k)的无序点对数量。
点分治(基于重心的树分治)。对于每次选出的重心,分别统计经过这个点的、满足条件的路径数量,然后删去该点,对剩下的各连通块递归处理。
具体而言,对于一个固定的重心,先dfs求出所有未访问节点到当前节点的距离并排序,再利用双指针(two-pointer)在O(n)时间内求出所有距离和<=k的点对数量(排序本身为O(n log n))。此时多算了路径起点终点都在同一棵子树内的情况,为了消除这一部分,再对每棵子树做同样的统计(子树内每个点的距离都包含连接该子树与重心的那条边的长度),并把结果减去。
#include <cstdio>
#include <iostream>
#include <string.h>
#include <string>
#include <map>
#include <queue>
#include <vector>
#include <set>
#include <algorithm>
#include <math.h>
#include <cmath>
#include <bitset>
// Fill every byte of the array object `a` (sizeof is taken on `a` itself,
// so `a` must be a real array, not a pointer).
#define mem0(a) memset((a), 0, sizeof(a))
#define meminf(a) memset((a), 0x3f, sizeof(a))
using namespace std;

typedef long long ll;
typedef long double ld;

const int maxn = 200005;  // capacity of the per-vertex arrays
const int inf = 0x3f3f3f3f;
const ll llinf = 0x3f3f3f3f3f3f3f3f;
const ld pi = acos(-1.0L);

// NOTE: with `using namespace std;`, C++17's std::size (and std::visit when
// <variant> is reachable) make unqualified uses of the arrays `size` and
// `visit` below ambiguous; refer to them as ::size / ::visit.
int head[maxn];    // head[v]: index in edge[] of v's latest arc, -1 when none
int size[maxn];    // subtree sizes written by getroot()
int ms[maxn];      // ms[v]: largest piece left if v were removed (getroot())
int d[maxn];       // distances gathered by dfs(), stored in d[1..cnt]
int num = 0;       // next free slot in edge[]
int root = -1;     // best centroid candidate found by getroot()
int rs = inf;      // ms[] value of `root`
int sum;           // size of the component getroot() is splitting
int k;             // distance limit read from the input
int ans;           // running answer for the current test case
int cnt = 0;       // number of distances currently in d[]
bool visit[maxn];  // set once a vertex has been used as a centroid
char s[maxn];      // NOTE(review): not referenced in the visible code
// One directed arc of the adjacency-list representation. Field order
// matters: arcs are built with aggregate initialisation {from,to,pre,dist}.
struct Edge {
    int from;  // tail vertex
    int to;    // head vertex
    int pre;   // previous arc out of `from` in edge[]; -1 ends the list
    int dist;  // edge length
};
// Arc pool for the adjacency lists; addedge() consumes two slots per
// undirected edge, hence the factor 2.
Edge edge[maxn*2];
void addedge(int from,int to,int dist) {
edge[num]=(Edge){from,to,head[from],dist};
head[from]=num++;
edge[num]=(Edge){to,from,head[to],dist};
head[to]=num++;
}
void getroot(int now,int fa) {
size[now]=ms[now]=0;
for (int i=head[now];i!=-1;i=edge[i].pre) {
int to=edge[i].to;
if (!visit[to]&&to!=fa) {
getroot(to,now);
size[now]+=size[to];
ms[now]=max(ms[now],size[to]);
}
}
size[now]++;
ms[now]=max(ms[now],sum-size[now]);
if (ms[now]<rs) root=now,rs=ms[now];
}
void dfs(int now,int fa,int dis) {
d[++cnt]=dis;
for (int i=head[now];i!=-1;i=edge[i].pre) {
int to=edge[i].to;
if (!visit[to]&&to!=fa) dfs(to,now,dis+edge[i].dist);
}
}
// Count unordered pairs among the distances collected by dfs(now, 0, dis)
// whose sum is at most k. Every distance is offset by `dis`, so
// cal(child, centroid, w) measures distances to the centroid through the
// connecting edge of length w.
// `fa` is not used: dfs() is started with parent 0, and in solve() the
// centroid is marked visited beforehand, which is what stops the walk.
int cal(int now,int fa,int dis) {
    (void)fa;
    cnt = 0;
    dfs(now, 0, dis);
    sort(d + 1, d + cnt + 1);
    int pairs = 0;
    int lo = 1;
    int hi = cnt;
    // Two pointers over the sorted distances: once d[lo] + d[hi] <= k,
    // every partner in (lo, hi] also pairs with d[lo].
    while (lo < hi) {
        if (d[lo] + d[hi] > k) {
            --hi;
            continue;
        }
        pairs += hi - lo;
        ++lo;
    }
    return pairs;
}
void solve(int now) {
visit[now]=1;
ans+=cal(now,0,0);
for (int i=head[now];i!=-1;i=edge[i].pre) {
int to=edge[i].to;
if (!visit[to])
ans-=cal(to,now,edge[i].dist);
}
for (int i=head[now];i!=-1;i=edge[i].pre) {
int to=edge[i].to;
if (!visit[to]) {
root=-1,rs=inf;
sum=size[to];
getroot(to,0);
solve(root);
}
}
}
int main() {
int i,j,x,y,z,n;
scanf("%d%d",&n,&k);
while (n||k) {
num=ans=0;
memset(head,-1,sizeof(head));
for (i=1;i<n;i++) {
scanf("%d%d%d",&x,&y,&z);
addedge(x,y,z);
}
sum=n;
getroot(1,0);
mem0(visit);
solve(root);
printf("%d\n",ans);
scanf("%d%d",&n,&k);
}
return 0;
}