传送门:http://poj.org/problem?id=1741
写的第一道树分治题,撒花纪念~
对于每一对点对(i, j),它有三种情况:
① 其中一个是根节点。这种情况比较简单,直接加上就好了。
② 横跨根节点。这种情况是重点。
③ 不是以上两种情况。这时递归下去求解就好了。
那么对于第二种情况该怎么破呢?设根节点为root,那么dist(i, root) + dist(j, root) <= k,且需要i与j在不同的子树里。直接算不同子树的点对(i, j)的个数会麻烦,所以需要一点技巧:符合条件且在不同子树的(i, j)的对数 = 符合条件的对数 - 符合条件且在相同子树的(i, j)的对数,这样就搞定啦!
#include <cstdio>
#include <cstring>
#include <algorithm>
const int maxn = 10005;
int n, k, t1, t2, t3, ans;
int head[maxn], to[maxn << 1], next[maxn << 1], w[maxn << 1], lb;
int siz[maxn], a[maxn], left, right;
bool book[maxn];
inline void ist(int aa, int ss, int ww) {
to[lb] = ss;
next[lb] = head[aa];
head[aa] = lb;
w[lb] = ww;
++lb;
}
int fnd_zx(int fr, int tot_node, int p, int & rt, int & mn) {
int mx = 0;
for (int j = head[fr]; j != -1; j = next[j]) {
if (!book[to[j]] && to[j] != p) {
fnd_zx(to[j], tot_node, fr, rt, mn);
mx = std::max(mx, siz[to[j]]);
}
}
mx = std::max(mx, tot_node - siz[fr]);
if (mn > mx) {
mn = mx;
rt = fr;
}
}
void get_siz(int fr, int p) {
siz[fr] = 1;
for (int j = head[fr]; j != -1; j = next[j]) {
if (!book[to[j]] && to[j] != p) {
get_siz(to[j], fr);
siz[fr] += siz[to[j]];
}
}
}
void get_data(int r, int p, int ww) {
if (ww > k) {
return;
}
a[right++] = ww;
for (int j = head[r]; j != -1; j = next[j]) {
if (!book[to[j]] && to[j] != p) {
get_data(to[j], r, ww + w[j]);
}
}
}
int get_ans(int l, int r) {
std::sort(a + l, a + r);
int rt = 0;
--r;
while (r > l) {
while (r > l && a[l] + a[r] > k) {
--r;
}
rt += r - l;
++l;
}
return rt;
}
void slove(int fr) {
int root = -666, mn = 2147483647;
get_siz(fr, 0);
fnd_zx(fr, siz[fr], 0, root, mn);
book[root] = true;
for (int j = head[root]; j != -1; j = next[j]) {
if (!book[to[j]]) {
slove(to[j]);
}
}
left = right = 0;
for (int j = head[root]; j != -1; j = next[j]) {
if (!book[to[j]]) {
get_data(to[j], root, w[j]);
ans -= get_ans(left, right);
left = right;
}
}
ans += get_ans(0, right) + right;
book[root] = false;
}
int main(void) {
//freopen("in.txt", "r", stdin);
while (scanf("%d%d", &n, &k) && n && k) {
lb = 0;
memset(head, -1, sizeof head);
memset(next, -1, sizeof next);
ans = 0;
for (int i = 1; i < n; ++i) {
scanf("%d%d%d", &t1, &t2, &t3);
ist(t1, t2, t3);
ist(t2, t1, t3);
}
slove(1);
printf("%d\n", ans);
}
return 0;
}