Time Limit: 1000MS | Memory Limit: 30000K | |
Total Submissions: 22767 | Accepted: 7527 |
Description
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The last test case is followed by two zeros.
Output
Sample Input
5 4 1 2 3 1 3 1 1 4 2 3 5 1 0 0
Sample Output
8
题意:一棵有n个节点的树,每条边有个权值代表相邻2个点的距离,要求求出所有距离不超过k的点对(u,v)
题解:树分治
假设树以root为根节点,那么满足要求的点对有2种情况:
①路径经过root且dis(u,v)<=k
②路径不经过root,即其路径的最高点为子树上某一节点
对于第②种情况可以通过递归求解,这里只讨论第一种情况
该如何求解路径经过root且dis(u,v)<=k的合法点对数呢?
设dir[u]为u到根节点root的距离,那么只有满足dir[u]+dir[v]<=k且LCA(u,v)==root的点对才是合法的,
设cnt1=树中所有dis(u,v)<=k的点对数,cnt2=LCA(u,v)==root的子节点的合法点对数
那么以root为根的树种合法点对数为:ans=cnt1-cnt2
找出有多少个dir[u]+dir[v]的方法很简单:只需要排序后扫一遍即可。
总结一下算法的过程:
①计算以u为根的树种每棵子树的大小
②根据子树大小找出树的重心root(以树的重心为根的树,可以使其根的子树中节点最多的子树的节点最少)
③以root为根,计算树中每个点到root的距离dir
④计算树中所有满足dir[u]+dir[v]<=k的点对数cnt1
⑤计算以root的子节点为根的子树中,满足dir[u]+dir[v]<=k的点对数cnt2
⑥ans+=cnt1-cnt2
注意:每次计算完cnt1后,要将vis[root]=1,这样就可以将一棵树分解成若干棵子树
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int MX = 1e4 + 5;
struct Edge {
int v, w, nxt;
} E[MX * 2];
int n, k, root, Max, ans;
vector <int> dis;
int sz[MX], maxv[MX], head[MX], tot;
bool vis[MX];
void init() {
memset(vis, false, sizeof(vis));
memset(head, -1, sizeof(head));
tot = 0;
}
void add(int u, int v, int w) {
E[tot].v = v;
E[tot].w = w;
E[tot].nxt = head[u];
head[u] = tot++;
}
void dfs_size(int u, int fa) {
sz[u] = 1; maxv[u] = 0;
for (int i = head[u]; ~i; i = E[i].nxt) {
int v = E[i].v;
if (vis[v] || v == fa) continue;
dfs_size(v, u);
sz[u] += sz[v];
maxv[u] = max(maxv[u], sz[v]);
}
}
void dfs_root(int r, int u, int pre) { // 找出以u为根的子树的重心
maxv[u] = max(maxv[u], sz[r] - sz[u]);
if (Max > maxv[u]) {
Max = maxv[u];
root = u;
}
for (int i = head[u]; ~i; i = E[i].nxt) {
int v = E[i].v;
if (v == pre || vis[v]) continue;
dfs_root(r, v, u);
}
}
void dfs_dis(int u, int fa, int dir) {
dis.push_back(dir);
for (int i = head[u]; ~i; i = E[i].nxt) {
int v = E[i].v, w = E[i].w;
if (vis[v] || v == fa) continue;
dfs_dis(v, u, dir + w);
}
}
int cal(int rt, int d) {
dis.clear();
dfs_dis(rt, -1, d);
sort(dis.begin(), dis.end());
int i = 0, j = dis.size() - 1, ret = 0;
while (i < j) {
while (dis[i] + dis[j] > k && i < j) j--;
ret += j - i;
i++;
}
return ret;
}
void DFS(int u) {
Max = n;
dfs_size(u, -1);
dfs_root(u, u, -1);
int rt = root;
ans += cal(rt, 0);
vis[rt] = 1;
for (int i = head[rt]; ~i; i = E[i].nxt) {
int v = E[i].v, w = E[i].w;
if (vis[v]) continue;
ans -= cal(v, w);
DFS(v);
}
}
int main() {
//freopen("in.txt","r",stdin);
while (scanf("%d%d", &n, &k), n || k) {
init();
for (int i = 1; i < n; i++) {
int u, v, w;
scanf("%d%d%d", &u, &v, &w);
add(u, v, w); add(v, u, w);
}
ans = 0;
DFS(1);
printf("%d\n", ans);
}
return 0;
}