Description
Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u, v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u, v) of vertices is called valid if and only if dist(u, v) does not exceed k.
Write a program that counts how many valid pairs there are in the given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
题目大意
给定一棵带正边权的无根树,求链长不超过k的链的数量。
题解
先点分,考虑所有经过重心g的链,求出当前子树中所有点与g的距离,排序后用双指针求出距离和不超过k的点对数,再减掉两端属于g的同一个孩子子树的点对数(它们的真实路径并不经过g),然后对每个孩子子树递归即可。
复杂度
$O(n\log^2 n)$。
#include<cctype>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
const int inf = 0x3f3f3f3f;
// Reads the next (optionally negative) decimal integer from stdin, skipping any
// non-digit characters before it. A '-' seen while skipping makes the result negative.
// If end of input is reached before any digit, returns 0 instead of looping forever.
// init() treats n == 0 as "stop", so a missing "0 0" terminator no longer hangs.
inline int read(){
    int x = 0, f = 1;
    // int, not char: it must be able to hold EOF, and isdigit() requires an
    // unsigned-char value or EOF (a negative char would be undefined behaviour).
    int c = getchar();
    while(c != EOF && !isdigit(c)) { if(c == '-') f = -1; c = getchar(); }
    while(c != EOF && isdigit(c)) { x = x * 10 + c - '0'; c = getchar(); }
    return x * f;
}
// Array capacity: vertices are numbered 1..n with n <= 10000, plus some slack.
const int N = 10000 + 10;

// Parameters of the current test case: vertex count and distance limit.
int n, k;

// tot:  number of directed arcs stored so far (index of the last arc).
// sum:  size of the component in which getroot() searches for a centroid.
// ans:  running count of valid pairs for the current test case.
// rt:   best centroid candidate found by getroot() (0 = none yet).
int tot, sum, ans, rt;

// Adjacency lists: hd[u] is the first arc out of u; for each arc,
// nxt = next arc of the same vertex, val = edge length, to = target vertex.
// Every undirected edge occupies two arcs, hence the doubled capacity.
int hd[N];
int nxt[N << 1];
int val[N << 1];
int to[N << 1];

// f:   size of the largest piece left if the vertex were deleted.
// siz: subtree size from the latest getroot() traversal.
// st:  distance buffer filled by calst(); st[0] holds the element count.
// d:   distance of a vertex from the vertex cal() started at (plus its offset).
int f[N];
int siz[N];
int st[N];
int d[N];

// vis: vertices already used as a centroid, i.e. removed from the tree.
bool vis[N];
// Adds the undirected edge {u, v} of length w. One arc is pushed in each
// direction onto the front of the respective vertex's adjacency list:
// first u -> v, then v -> u.
void insert(int u, int v, int w){
    const int ends[2][2] = {{u, v}, {v, u}};
    for(int j = 0; j < 2; j++){
        const int from = ends[j][0];
        const int dest = ends[j][1];
        ++tot;
        to[tot] = dest;
        val[tot] = w;
        nxt[tot] = hd[from];
        hd[from] = tot;
    }
}
// Depth-first walk over the component containing u, not entering fa or any
// removed (vis) vertex. It fills siz[] with subtree sizes and f[] with the
// largest piece that would remain if each vertex were deleted. It moves rt to
// a vertex with strictly smaller f than the current rt, so the first minimum
// found wins ties. It expects sum to hold the component size, and f[0] to be
// large when rt starts at 0.
void getroot(int u, int fa){
    siz[u] = 1;
    int heaviest = 0;
    for(int e = hd[u]; e; e = nxt[e]){
        const int w = to[e];
        if(w != fa && !vis[w]){
            getroot(w, u);
            siz[u] += siz[w];
            heaviest = std::max(heaviest, siz[w]);
        }
    }
    // The part "above" u (everything outside its subtree) is also a piece.
    f[u] = std::max(heaviest, sum - siz[u]);
    if(f[u] < f[rt]) rt = u;
}
// Appends d[u] to the buffer st, whose length is stored in st[0]. Then, for
// each neighbour reachable without going back to fa or into a removed vertex,
// derives its distance from d[u] plus the edge length and recurses on it.
void calst(int u, int fa){
    st[0]++;
    st[st[0]] = d[u];
    for(int e = hd[u]; e != 0; e = nxt[e]){
        const int nb = to[e];
        if(nb != fa && !vis[nb]){
            d[nb] = d[u] + val[e];
            calst(nb, u);
        }
    }
}
// Counts unordered pairs of vertices reachable from u (not crossing removed
// vertices) whose distances d[] add up to at most k. Each d[] is measured as
// pre plus the path length from u. solve() calls it with pre = 0 from the
// centroid, and with pre = edge length from each neighbour to get the
// same-subtree pairs it must subtract.
int cal(int u, int pre){
    d[u] = pre;
    st[0] = 0;
    calst(u, 0);
    std::sort(st + 1, st + st[0] + 1);
    int pairs = 0;
    int l = 1;
    int r = st[0];
    while(l < r){
        if(st[l] + st[r] > k){
            --r;
        }else{
            // st[l] fits with every element in positions l+1..r.
            pairs += r - l;
            ++l;
        }
    }
    return pairs;
}
// Centroid-decomposition step for centroid u:
//   1. count the valid pairs whose path passes through u;
//   2. remove u;
//   3. recurse into the centroid of each remaining piece.
// It expects siz[] to hold the subtree sizes from the getroot() pass that chose
// u, and the global sum to hold the size of u's component.
void solve(int u){
    // Size of u's component. Saved because the recursive calls overwrite sum.
    const int total = sum;
    // Pairs whose distances to u add up to <= k. Includes u itself, at distance 0.
    ans += cal(u, 0);
    vis[u] = 1;
    for(int i = hd[u]; i; i = nxt[i]){
        int v = to[i];
        if(vis[v]) continue;
        // Pairs with both ends in v's piece were counted above, but their real
        // path does not pass through u, so take them back out.
        ans -= cal(v, val[i]);
        // siz[] was computed with the component rooted where getroot() started,
        // not at u. If v hung below u there, siz[v] is its piece size. If v was
        // u's parent, siz[v] is stale, and its piece is everything outside u's
        // subtree. Using the stale value could make getroot() pick a vertex
        // that is not a centroid.
        sum = siz[v] < siz[u] ? siz[v] : total - siz[u];
        rt = 0;
        getroot(v, 0);
        solve(rt);
    }
}
// Resets all per-test-case state and reads one test case. Returns false as
// soon as n reads as 0 (the terminating "0 0" line); otherwise reads the n-1
// edges into the adjacency lists and returns true.
bool init(){
    tot = 0;
    rt = 0;
    ans = 0;
    memset(hd, 0, sizeof(hd));
    memset(vis, 0, sizeof(vis));
    n = read();
    k = read();
    if(n == 0) return false;
    for(int left = n - 1; left > 0; --left){
        const int a = read();
        const int b = read();
        const int len = read();
        insert(a, b, len);
    }
    return true;
}
// Processes test cases until init() reports the end of input, printing the
// number of valid pairs for each one on its own line.
void work(){
    for(;;){
        if(!init()) break;
        // Sentinel: rt == 0 means "no centroid yet", so any real vertex beats it.
        f[0] = inf;
        // The first component searched is the whole tree.
        sum = n;
        getroot(1, 0);
        solve(rt);
        printf("%d\n", ans);
    }
}
// Program entry point: all the work happens in work(). Falling off the end of
// main returns 0.
int main()
{
    work();
}