Tree
Description
Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u, v) as the minimum distance between nodes u and v. Given an integer k, a pair (u, v) of vertices is called valid if and only if dist(u, v) does not exceed k. Write a program that counts how many valid pairs there are in a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
Source
解题报告:楼教主男人八题之一。树分治。
我觉得难点在于,在你计算复杂度之前,你肯定不会想到这么做。
每次统计当前每个节点子树的节点数,找到重心。这个复杂度为O(n)。统计重心到所有节点的距离,找到所有长度小于等于k的链。这里可以先排序,再用双指针以O(n)的算法找到所有解。排序复杂度O(n log n)。删除重心,并减去两端落在同一棵子树内的点对(去重),然后对各子树递归下去。每次找的都是重心,可以保证递归的深度不超过log n。故总复杂度为O(n log n log n)。
问题就解决啦……代码如下:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <queue>
#include <vector>
#include <map>
#include <set>
#include <string>
#include <iomanip>
#include <cassert>
using namespace std;
// Ask the MSVC linker for a ~1 GB stack to reduce the risk of overflow in the
// recursive DFS routines on long chains; other toolchains generally ignore it.
#pragma comment(linker, "/STACK:1024000000,1024000000")
// ff(i, n): i = 0 .. n-1 ascending.
#define ff(i, n) for(int i=0;i<(n);i++)
// fff(i, n, m): i = n .. m ascending (inclusive).
#define fff(i, n, m) for(int i=(n);i<=(m);i++)
// dff(i, n, m): i = n .. m descending (inclusive).
#define dff(i, n, m) for(int i=(n);i>=(m);i--)
// travel(e, u): walk a linked edge list whose FIRST EDGE INDEX is `u`
// (callers pass edge[vertex], not a vertex). Each iteration binds the edge
// index `e` and its target vertex `v` (= vv[e]); index 0 ends the list.
// Note: the `u` argument is textually substituted twice.
#define travel(e, u) for(int e = u, v = vv[u]; e; e = nxt[e], v = vv[e])
// bit(n): 64-bit value with only bit n set.
#define bit(n) (1LL<<(n))
typedef long long LL;
typedef unsigned long long ULL;
void work();

// Program entry point: all solving happens in work().
int main()
{
#ifdef ACM
    // Local debugging build (-DACM): take the test data from in.txt.
    freopen("in.txt", "r", stdin);
#endif // ACM
    work();
    return 0;
}
// Minimal fast reader for a non-negative decimal integer from stdin.
// Skips every non-digit character, then accumulates consecutive digits into x;
// the character that terminates the number is consumed. Signs are not
// recognised ("-7" reads as 7).
// If end of input is reached before any digit, x is set to 0 and the function
// returns. getchar()'s result is kept in an int so EOF is distinguishable from
// real characters; storing it in a char makes the skip loop spin forever.
// `ch` is unused; it is kept only so the signature stays unchanged.
void scanf(int & x, char ch = 0)
{
    (void)ch;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    if (c == EOF)
    {
        x = 0;
        return;
    }
    x = c - '0';
    while ((c = getchar()) >= '0' && c <= '9') x = 10 * x + (c - '0');
}
/***************************************************************************************/
const int maxv = 11111;   // vertex capacity (problem guarantees n <= 10000)

int n;                    // vertex count of the current test case
int k;                    // distance limit: a pair is valid iff dist(u, v) <= k
int ans;                  // running count of valid pairs

// Adjacency lists as linked edge pools: edge[u] is the first edge index of
// vertex u; nxt/vv/ww give the next index, target vertex and length of an
// edge. Each undirected edge is stored twice, hence the doubled capacity.
int edge[maxv];
int ecnt;                 // next free slot in the edge pool
int nxt[maxv * 2];
int vv[maxv * 2];
int ww[maxv * 2];

bool vis[maxv];           // vertices already taken as a centroid (removed)
int siz[maxv];            // subtree sizes from the latest dfsSize pass
int mson[maxv];           // largest remaining piece if the vertex were removed
int mi;                   // smallest mson seen during the current centroid search
int root;                 // centroid chosen by the current search
int tot;                  // number of entries currently stored in dis
int dis[maxv];            // distances gathered by dfsDis
// Resets all per-test-case state before a new tree is read.
// Edge indices 0 and 1 are never handed out; 0 doubles as the end-of-list
// marker for the adjacency lists.
void init()
{
    memset(vis, 0, sizeof vis);
    memset(edge, 0, sizeof edge);
    ecnt = 2;
    ans = 0;
}
// Pushes a directed edge u -> v of length w onto the front of u's list.
// `first` holds the list heads (callers pass the global `edge` array).
void addEdge(int u, int v, int w, int first[])
{
    const int id = ecnt++;
    vv[id] = v;
    ww[id] = w;
    nxt[id] = first[u];
    first[u] = id;
}
// Recomputes sizes for the component reachable from u without entering
// removed (vis) vertices, with f as u's parent (0 = none).
// Afterwards, for every x in the component, siz[x] is the size of x's subtree
// and mson[x] is the size of its largest child subtree.
void dfsSize(int u, int f)
{
    siz[u] = 1;
    mson[u] = 0;
    travel(e, edge[u])
    {
        if (vis[v] || v == f) continue;   // removed vertex, or back to the parent
        dfsSize(v, u);
        siz[u] += siz[v];
        if (siz[v] > mson[u]) mson[u] = siz[v];
    }
}
// Centroid search over the component rooted at r (sizes from dfsSize).
// For each vertex u, the largest piece left after removing u is either its
// heaviest child subtree or everything outside u's subtree. The vertex that
// minimises this is recorded in `root`, with the minimum in `mi`.
void dfsGravity(int r, int u, int f)
{
    const int outside = siz[r] - siz[u];
    if (outside > mson[u]) mson[u] = outside;
    if (mson[u] < mi)
    {
        mi = mson[u];
        root = u;
    }
    travel(e, edge[u])
    {
        if (vis[v] || v == f) continue;
        dfsGravity(r, v, u);
    }
}
// Appends to dis[] the distance of every vertex reachable from u without
// entering removed (vis) vertices, where u itself is at distance d and f is
// u's parent (0 = none). tot is advanced once per vertex visited.
void dfsDis(int u, int f, int d)
{
    dis[tot] = d;
    ++tot;
    travel(e, edge[u])
    {
        if (!vis[v] && v != f)
        {
            const int next_d = d + ww[e];
            dfsDis(v, u, next_d);
        }
    }
}
// Gathers the distances of all vertices reachable from u (u itself at
// distance d) and returns the number of unordered vertex pairs whose two
// distances sum to at most k.
// dfs() calls this with d = 0 on a centroid, and with d = edge length on each
// child of the centroid to obtain the pairs it must subtract.
int calc(int u, int d = 0)
{
    tot = 0;
    dfsDis(u, 0, d);
    sort(dis, dis + tot);

    int pairs = 0;
    int lo = 0;
    int hi = tot - 1;
    while (lo < hi)
    {
        if (dis[lo] + dis[hi] <= k)
        {
            // dis is sorted: every partner of lo in (lo, hi] also fits.
            pairs += hi - lo;
            ++lo;
        }
        else
        {
            --hi;
        }
    }
    return pairs;
}
// Centroid-decomposition step for the component containing u (vertices not
// yet marked in vis). Valid pairs found here are added to the global ans.
//  1. Locate the component's centroid.
//  2. Add all pairs whose distances from the centroid sum to <= k.
//  3. Mark the centroid removed. For each child subtree, subtract the pairs
//     counted in step 2 that lie entirely inside that subtree (their real
//     path does not pass through the centroid), then recurse into it.
void dfs(int u)
{
    mi = n;
    dfsSize(u, 0);
    dfsGravity(u, u, 0);
    // `root` is global and gets overwritten by the recursive calls below.
    const int centre = root;

    ans += calc(centre);
    vis[centre] = true;

    travel(e, edge[centre])
    {
        if (vis[v]) continue;
        ans -= calc(v, ww[e]);
        dfs(v);
    }
}
void work()
{
while(scanf("%d%d", &n, &k) == 2 && (n || k))
{
init();
ff(i, n-1)
{
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
addEdge(u, v, w, edge);
addEdge(v, u, w, edge);
}
dfs(1);
cout << ans << endl;
}
}