Tree
Time Limit: 1000MS | Memory Limit: 30000K | |
Total Submissions: 8103 | Accepted: 2375 |
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4 1 2 3 1 3 1 1 4 2 3 5 1 0 0
Sample Output
8
Source
题意:给出一棵树,找到所有两个点之间距离是小于等于k的对的个数
思路:这道题在这里的解题报告的帮助下完成的http://blog.csdn.net/woshi250hua/article/details/7723400 ,对于两个点都可以看成是从一个点经过这个点的一个根到另一个点(另一个点可以就是这个根),所以统计的时候先把从根出发到每个儿子的距离(距离<=k) 算出来,然后把这些距离拿出来排个序,找到dist[i]+dis[j]<=k 的个数,然后为了避免重复,我们只取经过该根的对,其余的减掉, 然后我们继续递归到子树中进行同样的操作。 这里有个要注意的问题就是每一次进入一颗子树中,我们都要为这棵子树找到一个重心,否则在比较坏的情况下会超时(如树是一条链)具体实现请看代码。
代码:
#include<iostream>
#include<cstdio>
#include<string.h>
#include<algorithm>
#include<string>
#include<deque>
#include<queue>
#include<math.h>
#include<vector>
#include<map>
#include<stack>
#include<set>
using namespace std;
const int MAX = 10000+10;
#define MOD 99997
const int inf = 0xffffff;
#define max(a,b) (a) < (b) ? (b) : (a)
int k , m , tot , ans;
bool vis[MAX];
int dist[MAX];
int sign[MAX];
int size[MAX];
struct Node
{
int v;
int len;
Node *next;
int sum;
int bal;
}*first[MAX] , edge[MAX*2] , tp[MAX];
void init()
{
memset(vis,0,sizeof(vis));
memset(first,0,sizeof(first));
m = ans = 0;
}
void add(int x,int y,int c)
{
edge[++m].v = y;
edge[m].len = c;
edge[m].next = first[x];
first[x] = &edge[m];
}
void GetDist(int son,int fa,int dis)
{
Node *p = first[son];
dist[tot++] = dis;
while (p)
{
if (p->v!=fa && !vis[p->v] && dis+p->len<=k)
{
GetDist(p->v,son,dis+p->len);
}
p = p->next;
}
}
void dfs(int son,int fa)
{
tp[son].sum = 1;
tp[son].bal = 0;
Node *p = first[son];
while (p)
{
if (p->v!=fa && !vis[p->v])
{
dfs(p->v,son);
tp[son].sum += tp[p->v].sum;
tp[son].bal = max(tp[son].bal,tp[p->v].sum);
}
p = p->next;
}
size[tot] = tp[son].bal;
sign[tot++] = son;
}
int GetRoot(int son,int fa)
{
tot = 0 , dfs(son,fa);
int rt , min_sz = inf;
for (int i = 0 ; i < tot ; ++i)
{
size[i] = max(size[i],tp[son].sum-size[i]);
if (size[i] < min_sz)
{
min_sz = size[i];
rt = sign[i];
}
}
return rt;
}
int Bisearch(int left,int right,int x)
{
int mid = (left+right)>>1;
while (left<=right)
{
if (dist[mid]<=x) left = mid+1;
else right = mid-1;
mid = (left+right)>>1;
}
return left;
}
int getPair()
{
int ret = 0;
sort(dist,dist+tot);
for (int i = 0 ; i < tot ; ++i)
{
int p = Bisearch(i+1,tot-1,k-dist[i])-1;
if (dist[p]+dist[i] > k) break;
ret += p-i;
}
return ret;
}
void Solve(int son,int fa)
{
int rt = GetRoot(son,fa);
vis[rt] = true;
tot = 0 , GetDist(rt,fa,0);
ans += getPair();
Node *p = first[rt];
while (p)
{
if (p->v!=fa && !vis[p->v])
{
tot = 0 , GetDist(p->v,rt,p->len);
ans -= getPair();
}
p = p->next;
}
p = first[rt];
while (p)
{
if (p->v!=fa && !vis[p->v])
Solve(p->v,rt);
p = p->next;
}
}
int main()
{
int n;
while (scanf("%d%d",&n,&k) , n || k)
{
init();
while (--n)
{
int x,y,c;
scanf("%d%d%d",&x,&y,&c);
add(x,y,c);
add(y,x,c);
}
Solve(1,-1);
printf("%d\n",ans);
}
}