Description
给定一棵
n
个节点的带权树,称距离不超过
n⩽10000 ,多组数据。
Solution
如果使用普通的DFS遍历,时间复杂度为
O(n2)
.无法接受.
考虑使用点分治,对于根节点u,路径可以分为经过点u和不经过的点u的路径.对于不经过点u的路径,递归处理子树即可.
接下来分析如何处理经过点u的路径:
记
depi
表示节点
i
到根节点u的距离.
则所求为满足
depi+depj⩽k
且
beli≠belj
的点对
(i,j)
个数.
即为满足
depi+depj⩽k
的点对数减去满足
depi+depj⩽k
且
beli=belj
的点对数.
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
using namespace std;
const int maxn = 10005;
struct edge {
int to, next, w;
}e[maxn * 2];
int n, k, h[maxn], ans, tot;
inline void add(int u, int v,int w)
{
e[++tot] = (edge) {v, h[u], w};
h[u] = tot;
}
inline int gi()
{
char c = getchar();
while(c < '0' || c > '9') c = getchar();
int sum = 0;
while('0' <= c && c <= '9') sum = sum * 10 + c - 48, c = getchar();
return sum;
}
int siz[maxn], sum, cent, Mxsz[maxn];
bool vis[maxn];
void getroot(int u, int fa)
{
siz[u] = 1; Mxsz[u] = 0;
for(int i = h[u], v; i; i = e[i].next)
if((v = e[i].to) != fa && !vis[v]) {
getroot(v, u); siz[u] += siz[v];
Mxsz[u] = max(Mxsz[u], siz[v]);
}
Mxsz[u] = max(Mxsz[u], sum - siz[u]);
if(Mxsz[u] < Mxsz[cent]) cent = u;
}
int que[maxn], cnt, dep[maxn];
void dfs(int u, int fa)
{
que[++cnt] = dep[u];
for(int i = h[u], v; i; i = e[i].next)
if((v = e[i].to) != fa && !vis[v]) dep[v] = dep[u] + e[i].w, dfs(v, u);
}
inline bool cmp(const int &a, const int &b)
{
return a < b;
}
int calc(int u, int w)
{
cnt = 0; dep[u] = w; dfs(u, 0);
sort(que + 1, que + cnt + 1, cmp);
int ret = 0, l, r;
for(l = 1, r = cnt; l < r; )
if(que[l] + que[r] <= k) ret += r - l, ++l;
else --r;
return ret;
}
int solve(int u)
{
cent = 0; getroot(u, 0); u = cent; vis[u] = true;
ans += calc(u, 0);
for(int i = h[u], v; i; i = e[i].next)
if(!vis[v = e[i].to]) {
ans -= calc(v, e[i].w);
sum = siz[v]; solve(v);
}
}
int main()
{
freopen("tree.in", "r", stdin);
freopen("tree.out", "w", stdout);
while(1) {
n = sum = gi(); k = gi();
if(!n) break;
tot = ans = 0;
memset(h, 0, sizeof(int) * (n + 1));
memset(vis, 0, sizeof(bool) * (n + 1));
for(int u, v, w, i = 1; i < n; ++i) {
u = gi(); v = gi(); w = gi();
add(u, v, w); add(v, u, w);
}
Mxsz[0] = n + 1; solve(1);
printf("%d\n", ans);
}
return 0;
}