一 问题描述
一棵有 n 个节点的树,每条边都有一个长度(小于 1001 的正整数),dist(u , v ) 为节点 u 和 v 的最小距离。给定一个整数 k ,对每对节点(u , v ),当且仅当 dist(u , v )不超过 k 时才叫作有效。计算给定的树中有多少对节点是有效的。
二 输入和输出
1 输入
输入包含几个测试用例。每个测试用例的第 1 行都包含两个整数 n 、k (n ≤10000),下面的 n -1 行,每行都包含三个整数 u、v 、l ,表示节点 u 和 v 之间有一条长度为 l 的边。在最后一个测试用例后面跟着两个0。
2 输出
对每个测试用例,都单行输出答案。
三 输入和输出样例
1 输入样例
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
2 输出样例
8
四 分析
根据测试用例的输入数据,树形结构如下图所示。树中距离不超过 4 的有 8 对节点:1-2、1-3、1-4、1-5、2-3、3-4、3-5、4-5。
查询树中有多少对节点距离不超过 k ,相当于查询树上两点之间距离不超过 k 的路径有多少条。可采用点分治解决。当数据量很大时,树上两点之间的路径很多,采用暴力穷举的方法是不可行的,可以采用树上分治算法进行点分治。以树的重心 root 为划分点,则树上两点 u 、v 的路径分为两种:① 经过root;② 不经过root(两点均在 root 的一棵子树中),只需求解第1类路径,对第2类路径根据分治策略继续采用重心分解即可得到。
五 设计
1 求树的重心 root。
2 从树的重心 root 出发,统计每个节点到 root 的距离。
3 对距离数组排序,以双指针扫描,统计以 root 为根的子树中满足条件的节点数。
4 对root的每一棵子树 v 都减去重复统计的节点数。
5 从 v 出发重复上述过程。
六 图解
求解树上两点之间距离(路径长度)不超过 4 的路径数。
1 求解树的重心,root =1。
2 从树的重心 root 出发,统计每个节点到 root 的距离,得到距离数组 dep[]。
3 对距离数组进行非递减排序,结果如下图所示。然后以双指针扫描,统计以 root 为根的子树中满足条件的节点对数。
a L =1,R =7,因为 dep[L ] + dep[R ] > 4,则 R--。
b L =1,R = 5,dep[L ]+dep[R ] <= 4,则 ans += R - L = 4,L++。
c L = 2,R = 5,若 dep[L] + dep[R] ≤ 4,则 ans += R -L = 4+3 = 7,L++。
d L = 3,R = 5,因为 dep[L]+dep[R] > 4,则 R--。
e L = 3,R = 4,因为 dep[L] + dep[R] ≤ 4,则 ans += R - L = 7+1 = 8,L++,此时 L = R ,算法停止。
也就是说,以 1 为根的树,满足条件的路径数有 8 个。在这些路径中,有些是合并路径,例如两条路径 1-2 和 1-3,其路径长度之和为 4,满足条件。这相当于将两条路径合并为 2-1-3,路径长度为 4。
路径长度小于或等于 4 的 8 条路径如下表所示。
第 7 条路径的合并是错误的。路径 1-3 和路径 1-3-7 的路径长度之和虽然小于或等于 4,但是不可以作为合并路径,因为树中任意两个节点之间的路径都是不重复的。而路径 1-3 和路径 1-3-7 之间的路径有重复,所以这样的路径不可以作为合并路径。可以先统计该路径,然后在处理以 3 为根的子树时去重。
4 对 root 的每一棵子树 v 都先去重,然后求以 v 为根的子树的重心,重复上述过程。
5 去重
去重。以 2 为根的子树没有重复统计的路径。
6 以 2 为根的子树的重心为 2,该子树满足条件的路径有3条,ans +=3=8+3=11,这 3 条为 2-5、2-6、2-5, 5-6(相当于一条合并路径 5-2-6,路径长度为 4 )。
7 去重。以 3 为根的子树,该子树有一条重复统计的路径(1-3 和 1-3-7 的合并路径)。减去重复路径,ans-1 = 10。
8 以 3 为根的子树的重心为 3,该子树满足条件的路径有 1 条(3-7),路径长度为 1,ans+1 = 11。
9 以 4 为根的子树的重心为 4,该子树没有重复统计的路径,也没有满足条件的路径。
七 算法实现
1 求树的重心
只需进行一次深度优先遍历,找到删除该节点后最大子树最小的节点。用 f[u] 表示删除 u 后最大子树的大小,size[u]表示以 u 为根的子树的节点数,S 表示整棵子树的节点数。先统计 u 的所有子树中最大子树的节点数 f[u ],然后与 S-size[u ]比较,取最大值。
若 f[u] < f[root],则更新当前树的重心为 root=u 。
2 统计每个节点到重心 u 的距离。把 dep[0]当作计数器使用,初始化为0,深度优先遍历,将每个节点到 u 的距离 d[] 都存入dep 数组中。
3 统计重心 u 的子树中满足条件的个数。初始化 d[u]=dis 且 dep[0]=0(用于计数),将每个节点到 u 的距离 d[] 都存入 dep 数组中;然后对 dep 数组排序,L=1,R=dep[0](dep数组末尾的下标),用 sum 累加满足条件的节点对个数。
4 对重心 u 的所有子树都先去重,然后递归求解答案。对 u 的每一棵子树 v 都减去 v 中重复统计的答案,然后从 v 出发重复上述过程。
八 代码
package com.platform.modules.alg.alglib.poj1741;
import java.util.Arrays;
public class Poj1741 {
private int maxn = 10005;
int cnt, n, k, ans;
int head[] = new int[maxn];
int root, S;
int size[] = new int[maxn];
int f[] = new int[maxn];
int d[] = new int[maxn];
int dep[] = new int[maxn];
boolean vis[] = new boolean[maxn];
edge edge[] = new edge[maxn * 2];
public String output = "";
public Poj1741() {
for (int i = 0; i < edge.length; i++) {
edge[i] = new edge();
}
}
// 获取重心
void getroot(int u, int fa) {
size[u] = 1;
f[u] = 0; // 删除 u 后,最大子树的大小
for (int i = head[u]; i > 0; i = edge[i].next) {
int v = edge[i].to;
if (v != fa && !vis[v]) {
getroot(v, u);
size[u] += size[v];
f[u] = Math.max(f[u], size[v]);
}
}
f[u] = Math.max(f[u], S - size[u]); // S为当前子树总结点数
if (f[u] < f[root])
root = u;
}
// 获取距离
void getdep(int u, int fa) {
dep[++dep[0]] = d[u]; // 保存距离数组
for (int i = head[u]; i > 0; i = edge[i].next) {
int v = edge[i].to;
if (v != fa && !vis[v]) {
d[v] = d[u] + edge[i].w;
getdep(v, u);
}
}
}
// 获取 u 的子树中满足个数
int getsum(int u, int dis) {
d[u] = dis;
dep[0] = 0;
getdep(u, 0);
Arrays.sort(dep, 1, 1 + dep[0]);
int L = 1, R = dep[0], sum = 0;
while (L < R)
if (dep[L] + dep[R] <= k) {
sum += R - L;
L++;
} else
R--;
return sum;
}
// 获取答案
void solve(int u) {
vis[u] = true;
ans += getsum(u, 0);
for (int i = head[u]; i > 0; i = edge[i].next) {
int v = edge[i].to;
if (!vis[v]) {
ans -= getsum(v, edge[i].w);//减去重复
root = 0;
S = size[v];
getroot(v, u);
solve(root);
}
}
}
void add(int u, int v, int w) {
edge[++cnt].to = v;
edge[cnt].w = w;
edge[cnt].next = head[u];
head[u] = cnt;
}
public String cal(String input) {
f[0] = 0x7fffffff; // 初始化树根
String[] line = input.split("\n");
String[] num = line[0].split(" ");
n = Integer.parseInt(num[0]);
k = Integer.parseInt(num[1]);
cnt = 0;
ans = 0;
for (int i = 1; i <= n - 1; i++) {
String[] edge = line[i].split(" ");
int x, y, z;
x = Integer.parseInt(edge[0]);
y = Integer.parseInt(edge[1]);
z = Integer.parseInt(edge[2]);
add(x, y, z);
add(y, x, z);
}
root = 0;
S = n;
getroot(1, 0);
solve(root);
output = ans + "";
return output;
}
}
class edge {
int to, next, w;
}