今天来搞1741 Tree
男人进度:2/8
题目链接
http://poj.org/problem?id=1741
题目描述
Give a tree with n n n vertices,each edge has a length(positive integer less than 1001).
The defination of d i s t ( u , v ) dist(u,v) dist(u,v) is The min distance between node u u u and v v v.
Give an integer k k k,for every pair ( u u u, v v v) of vertices is called valid if and only if d i s t ( u , v ) dist(u,v) dist(u,v) not exceed k k k.
Write a program that will count how many pairs which are valid for a given tree.
输入
The input contains several test cases.
The first line of each test case contains two integers n n n, k k k. (n<=10000) The following n − 1 n-1 n−1 lines each contains three integers u u u, v v v, l l l, which means there is an edge between node u u u and v v v of length l l l.
The last test case is followed by two zeros.
输出
For each test case output the answer on a single line.
样例输入
5 4 1 2 3 1 3 1 1 4 2 3 5 1 0 0
样例输出
8
样例解释
样例给出的树如上图示,共有8条符合要求的边:
- 1->2,距离为 3
- 1->3,距离为 1
- 1->4,距离为 2
- 1->3->5,距离为 1+1=2
- 2->1->3,距离为 3+1=4
- 3->5,距离为 1
- 3->1->4,距离为 1+2=3
- 4->1->3->5,距离为 2+1+1=4
题解
知识点:树的重心
因为题目给出的是无向边,那么把哪个节点当作根节点都无所谓了。那么不妨设有一棵以节点1为根的树,长成下面这个样子:
仔细观察,我们可以依据是否包含根节点,把路径分为两类:
- 包含根节点(比如 1->2,4->1->5)
- 不包含根节点(比如 2->3, 5->6)
不难发现,不包含的根节点的路径,其所包含的节点必然都来自于同一棵子树。这就很符合递归的要求了,为了描述清楚,我们不妨先设:
- 以节点 i i i 为根的子树的合法路径总数为 t ( i ) t(i) t(i)
- 以节点 i i i 为根的子树的包含节点 i i i的合法路径总数为 f ( i ) f(i) f(i)
t
(
i
)
t(i)
t(i) 可以递归的定义为:
t
(
i
)
=
{
0
,
i
是
叶
子
∑
s
是
i
的
儿
子
t
(
s
)
+
f
(
i
)
,
i
不
是
叶
子
t(i)= \begin{cases} 0, i是叶子 \\ \sum_{s是i的儿子} t(s) + f(i), i不是叶子 \end{cases}
t(i)={0,i是叶子∑s是i的儿子t(s)+f(i),i不是叶子
问题简化了,答案可表示为 t ( 1 ) = ∑ i = 1 n f ( i ) t(1) = \sum_{i=1}^{n} f(i) t(1)=∑i=1nf(i)。那么如何求解 f ( i ) f(i) f(i) 咧?
对于每个节点 i i i,做对应子树遍历,计算出该子树中每个节点和节点 i i i 的距离。
比如有上面这样一棵树,为了方便,我们不妨设边的权值都为 1。那么每棵子树上的节点到对应根节点的距离如下图示:
需要注意的是,真实场景中的权值并不相等,若想得到关于距离升序的数组,不得不进行一次排序。
不管怎么说,现在我们得到 n n n 个距离的数组。我们不妨设距离的限制为 K K K。
现在对于每个数组,进行一次双指针遍历操作:
- 首先指针 l l l 指向第一个元素,指针 r r r 指向最后一个元素,累加器 sum = 0。
- 接着下述两个步骤,直到 l = r l = r l=r
- 如果 l + r < = K l+r <= K l+r<=K,那么 s u m + = ( r − l ) sum += (r-l) sum+=(r−l),并且 l + + l++ l++。
- 如果 l + r > K l+r > K l+r>K,那么 r − − r-- r−−。
当 l + r > K l+r>K l+r>K 时,说明第 l l l 个节点和第 r r r 个节点的距离超限了,又因为数组是递增的,所以 r − − r-- r−− 尝试用一个距根节点更近的节点与 l l l 匹配。
当 l + r < = K l+r<=K l+r<=K 时,意味着第 l l l 个节点,与第 l + 1 , l + 2 , . . . , r l+1,l+2,...,r l+1,l+2,...,r 个节点的距离均不超过 K K K,且此时 r + 1 r+1 r+1 与 l l l 的距离定然是超限的,所以 s u m + = ( r − l ) sum\ += (r-l) sum +=(r−l)。然后 l + + l++ l++ 继续计算下一个节点的贡献。
但此时引入了一个新问题,得出的 s u m sum sum 掺进了一些杂质,以节点 1 1 1 为例:
设
K
=
4
K=4
K=4,那么:
KaTeX parse error: No such environment: align* at position 8: \begin{̲a̲l̲i̲g̲n̲*̲}̲ sum = &\ 7-1 \…
其中类似于:
- 4->2->1->2->5
- 3->1->3->6
这种重复经过某些点的路径也加进去了。
这些路径可以通过下述方法剔除,以节点 1 1 1 为例:
- 首先计算出 s u m = 21 sum = 21 sum=21。
- 然后分别计算出两棵子树中的节点,到 1 1 1 的距离。
-
通过双指针的计算方法,分别计算出两棵子树贡献的错误路径数量均为:
( 3 − 1 ) + ( 3 − 2 ) = 3 (3-1) + (3-2) = 3 (3−1)+(3−2)=3 -
可得包含节点 1 1 1 的路径总数应为
21 − 3 ∗ 2 = 15 21-3*2=15 21−3∗2=15
这样按照如上流程,可以计算出所有 f ( i ) f(i) f(i)。那么该题就到此为止了?男人八题就这?
接下来讨论下时间复杂度。从上述流程中不难看出,一棵有 n n n 个节点的树,要构造 2 ∗ n 2*n 2∗n 个距离数组:
- 一个是到子树根节点的。
- 一个是到子树根节点的父节点的。
那么每棵子树包含的节点数量就很重要了。考虑树退化成链表的情况:
那么,每棵树的节点数量为:
- t r e e 1 tree_1 tree1 有 n n n 个节点
- t r e e i tree_i treei 有 n − i + 1 n-i+1 n−i+1 个节点
- t r e e n tree_n treen 有 1 1 1 个节点
那么时间复杂度就是 O ( n 2 ) O(n^2) O(n2)。。。
接下来,引入一个新知识点树的重心。当把重心当做根节点时,可以保证最大子树的也不会超过 n 2 \frac{n}{2} 2n个节点。计算过程很简单,首先确认一个节点当做根节点,不妨先选最小的节点当做根。再设 c n t i cnt_i cnti 是以 i i i 为根的子树的节点数, o t h i oth_i othi 为该子树之外的节点数量。
比如这样一棵树:
- c n t 1 = 4 , o t h 1 = 0 cnt_1 = 4, oth_1 = 0 cnt1=4,oth1=0
- c n t 2 = 3 , o t h 2 = 1 cnt_2 = 3, oth_2 = 1 cnt2=3,oth2=1
- c n t 3 = 2 , o t h 3 = 2 cnt_3 = 2, oth_3 = 2 cnt3=2,oth3=2
- c n t 4 = 1 , o t h 4 = 3 cnt_4 = 1, oth_4 = 3 cnt4=1,oth4=3
找出 m a x ( o t h i , m a x ( c n t s ) ) max(oth_i, max(cnt_s)) max(othi,max(cnts)) 最小的节点 i,该节点就是树的重心,记为 c e n t e r center center。
m a x ( c n t s ) , s ∈ i 的 子 节 点 max(cnt_s), s ∈ {i的子节点} max(cnts),s∈i的子节点,表示 i i i 的最大的子树。
o t h i oth_i othi 的意义其实是把 i i i 的当做整棵树的根时,多出来的那颗子树,这个子树的根节点,就是现在 i i i 的父节点~
如下图所示,红色部分就是节点 5 的 o t h oth oth。
这样找出来根节点 c e n t e r center center 保证了最大的子树也不会超过 n 2 \frac{n}{2} 2n。让我们把链表变长,来看一看效果。
首先找到了
c
e
n
t
e
r
=
4
center=4
center=4。
对于两棵子树,分别找到重心为 2 和 6。
不难发现,通过把重心当做根的方式,可以保证子树的规模最起码会缩减一半。这样整体的时间复杂度就讲到了 O ( n ∗ lg n ) O(n*\lg n) O(n∗lgn)。
#include <stdio.h>
#include <string.h>
#include <algorithm>
using namespace std;
struct Node {
int v, w, next;
}edge[20001];
int edge_cnt = 0;
int head[10001];
bool erase_flag[10001];
int node_count[10001];
int focus_mark[10001];
int queue[10001];
void CountNode(int root, int pre) {
node_count[root] = 1;
for (int i = head[root]; i != -1; i = edge[i].next) {
int next = edge[i].v;
if (next != pre && erase_flag[next] == false) {
CountNode(next, root);
node_count[root] += node_count[next];
}
}
}
int GetCenter(int root) {
int cand_root = -1;
int cand_root_threshold = 0x3f3f3f3f;
int l = 0, r = 0;
queue[r++] = root;
while (l < r) {
int f = queue[l++];
int max_subtree_node_num = 0;
int total_subtree_node_num = 0;
for (int i = head[f]; i != -1; i = edge[i].next) {
int next = edge[i].v;
if (focus_mark[next] != root && erase_flag[next] == false) {
queue[r++] = next;
focus_mark[next] = root;
total_subtree_node_num += node_count[next];
max_subtree_node_num = max(max_subtree_node_num, node_count[next]);
}
}
max_subtree_node_num = max(node_count[root] - total_subtree_node_num - 1, max_subtree_node_num);
if (max_subtree_node_num < cand_root_threshold) {
cand_root = f;
cand_root_threshold = max_subtree_node_num;
}
}
return cand_root;
}
void GetDist(int root, int pre, int pre_dist, int k, int *dist, int &cnt) {
if (pre_dist > k) {
return;
}
dist[cnt++] = pre_dist;
for (int i = head[root]; i != -1; i = edge[i].next) {
int next = edge[i].v;
int d = edge[i].w;
if (next != pre && erase_flag[next] == false) {
GetDist(next, root, pre_dist + d, k, dist, cnt);
}
}
}
int total_dist[10001];
int GetTotalPair(int root, int pre, int pre_dist, int k) {
int dist_cnt = 0;
GetDist(root, pre, pre_dist, k, total_dist, dist_cnt);
sort(total_dist, total_dist + dist_cnt);
int total_pair = 0;
for (int l = 0, r = dist_cnt-1; l < r;) {
if (total_dist[l] + total_dist[r] <= k) {
total_pair += r-l;
l++;
} else {
r--;
}
}
return total_pair;
}
int DivideAndConquer(int root, int k) {
// 将 root 更新为重心
CountNode(root, 0);
root = GetCenter(root);
int total_pair = GetTotalPair(root, 0, 0, k);
for (int i = head[root]; i != -1; i = edge[i].next) {
int next = edge[i].v;
int dist = edge[i].w;
if (erase_flag[next] == false) {
total_pair -= GetTotalPair(next, root, dist, k);
}
}
erase_flag[root] = true;
for (int i = head[root]; i != -1; i = edge[i].next) {
int next = edge[i].v;
if (erase_flag[next] == false) {
total_pair += DivideAndConquer(next, k);
}
}
return total_pair;
}
int main() {
int n, k;
while(scanf("%d %d", &n, &k) && (n || k)) {
edge_cnt = 0;
memset(head, -1, sizeof(int)*(n+1));
for (int i = 1, u, v, w; i < n; i++) {
scanf("%d %d %d", &u, &v, &w);
edge[edge_cnt].v = v;
edge[edge_cnt].w = w;
edge[edge_cnt].next = head[u];
head[u] = edge_cnt++;
edge[edge_cnt].v = u;
edge[edge_cnt].w = w;
edge[edge_cnt].next = head[v];
head[v] = edge_cnt++;
}
memset(erase_flag, 0, sizeof(bool)*(n+1));
memset(focus_mark, 0, sizeof(int)*(n+1));
printf("%d\n", DivideAndConquer(1, k));
}
return 0;
}