Tree
Time Limit: 1000MS Memory Limit: 30000K Total Submissions: 26911 Accepted: 8953
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.Output
For each test case output the answer on a single line.
Sample Input
5 4 1 2 3 1 3 1 1 4 2 3 5 1 0 0Sample Output
8Source
【思路】
给定一棵树,边上有权,要求的是树上距离不超过k的点对数。
我们假设一棵有根树,维护树中的点到根的距离d,树中的任何路径必然有两种,1、经过根,2、不经过根。经过根的那些路径,其两端点u、v必然满足d[u] + d[v] <= k,不经过根的路径,也可以设根为子树中某一点来递归处理,问题解决。
子树中的路径加和判断可以通过O(NlogN)排序然后O(N)算出,那么我们的首要任务就变成了尽可能减少排序的次数,树形结构决定了排序次数和树的层数一致,所以也就需要通过寻找树的重心来减少树的层数,使整个算法复杂度在O(NlogN*logN)级别。
【代码】
//************************************************************************
// File Name: POJ_1741.cpp
// Author: Shili_Xu
// E-Mail: shili_xu@qq.com
// Created Time: 2018年02月26日 星期一 21时49分18秒
//************************************************************************
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int MAXN = 10005;
struct edge {
int to, len;
edge(int _to, int _len) : to(_to), len(_len) {}
};
int n, k, root, size;
int sz[MAXN], mxson[MAXN], d[MAXN];
bool visited[MAXN];
vector<edge> g[MAXN];
vector<int> dist;
void get_root(int u, int fa)
{
sz[u] = 1; mxson[u] = 0;
for (int i = 0; i < g[u].size(); i++) {
int v = g[u][i].to;
if (v != fa && !visited[v]) {
get_root(v, u);
sz[u] += sz[v];
mxson[u] = max(mxson[u], sz[v]);
}
}
mxson[u] = max(mxson[u], size - sz[u]);
if (mxson[u] < mxson[root]) root = u;
}
void get_dist(int u, int fa)
{
dist.push_back(d[u]);
for (int i = 0; i < g[u].size(); i++) {
int v = g[u][i].to;
if (v != fa && !visited[v]) {
d[v] = d[u] + g[u][i].len;
get_dist(v, u);
}
}
}
int cal(int u, int base)
{
dist.clear();
d[u] = base;
get_dist(u, 0);
sort(dist.begin(), dist.end());
int ans = 0, l = 0, r = dist.size() - 1;
while (l < r) {
if (dist[l] + dist[r] <= k)
ans += (r - l), l++;
else
r--;
}
return ans;
}
int work(int u)
{
visited[u] = true;
int ans = 0;
ans += cal(u, 0);
for (int i = 0; i < g[u].size(); i++) {
int v = g[u][i].to;
if (!visited[v]) {
ans -= cal(v, g[u][i].len);
root = 0; size = mxson[0] = sz[v];
get_root(v, 0);
ans += work(root);
}
}
return ans;
}
int main()
{
while (scanf("%d %d", &n, &k) == 2 && n != 0 && k != 0) {
for (int i = 1; i <= n; i++) g[i].clear();
for (int i = 1; i <= n - 1; i++) {
int a, b, c;
scanf("%d %d %d", &a, &b, &c);
g[a].push_back(edge(b, c));
g[b].push_back(edge(a, c));
}
int ans = 0;
root = 0; size = mxson[0] = n;
memset(visited, false, sizeof(visited));
get_root(1, 0);
ans = work(root);
printf("%d\n", ans);
}
return 0;
}