题目描述
给你 N N 个点的边带权的树,求满足 dist(i,j)<=K d i s t ( i , j ) <= K 的合法点对数目。
数据范围
N<=10000
N
<=
10000
K<=231−1
K
<=
2
31
−
1
分析
一直都没有认真地学习点分治,终于有时间好好来弄一下。
1.首先分析一下,如果是直接用树形dp来做的话,用
f[u][j]
f
[
u
]
[
j
]
表示到根
u
u
的距离为的点有多少个,并且确保每个点对统计的时候只能被他们最近公共祖先所统计到,时间复杂度为
O(NK)
O
(
N
K
)
,并不能达到预期的结果。
2.于是我们考虑用点分来解决这道题,点分的核心就是维护每次切分子树都选择树的重心来进行分治,确保每次树高都趋近于
O(log(size[u]))
O
(
l
o
g
(
s
i
z
e
[
u
]
)
)
,这样问题的复杂度就达到了
O(Nlog2(N))
O
(
N
l
o
g
2
(
N
)
)
。
3.下面就是计算答案的过程:
(I)
(
I
)
如果路径在子树中,那么就可以直接分治处理。
(II)
(
I
I
)
如果路径经过根,记
dist[i]
d
i
s
t
[
i
]
表示
i
i
节点到根的距离,满足且
i−>j
i
−
>
j
的路径经过根的合法对数目=满足
dist[i]+dist[j]<=K
d
i
s
t
[
i
]
+
d
i
s
t
[
j
]
<=
K
的合法对数-满足
dist[i]+dist[j]<=K
d
i
s
t
[
i
]
+
d
i
s
t
[
j
]
<=
K
且在
i,j
i
,
j
均在根节点的一棵子树中的合法对数。因为
Ai
A
i
满足单调性,所以可以
O(NlogN)
O
(
N
l
o
g
N
)
排序之后
O(N)
O
(
N
)
扫描一遍直接得出答案。
PS:poj的头文件真的#$%@#!@?
#include <bits/stdc++.h>
//#include <iostream>
//#include <cstdlib>
//#include <algorithm>
//#include <cstring>
//#include <cstdio>
#define rep( i , l , r ) for( int i = (l) ; i <= (r) ; ++i )
#define per( i , r , l ) for( int i = (r) ; i >= (l) ; --i )
#define erep( i , u ) for( int i = head[(u)] ; ~i ; i = e[i].nxt )
using namespace std;
int _read(){
char ch = getchar();
int x = 0 , f = 1 ;
while( !isdigit( ch ) )
if( ch == '-' ) f = -1 , ch = getchar();
else ch = getchar();
while( isdigit( ch ) )
x = (ch - '0') + x * 10 , ch = getchar();
return x * f;
}
const int maxn = 10000 + 5;
struct edge{
int v , w , nxt;
} e[maxn * 3];
int head[maxn] , _t = 0;
inline void addedge( int u , int v , int w ){
e[_t].v = v , e[_t].w = w , e[_t].nxt = head[u] , head[u] = _t++;
e[_t].v = u , e[_t].w = w , e[_t].nxt = head[v] , head[v] = _t++;
}
bool vis[maxn];
int sz[maxn] , mx[maxn] , size = 0 , N , K , root = 0 , _min = 0;
void dfs_core( int u , int f ){
sz[u] = 1 , mx[u] = 0;
erep( i , u ){
int v = e[i].v;
if( v == f || vis[v] ) continue;
dfs_core( v , u );
sz[u] += sz[v] , mx[u] = max( sz[v] , mx[u] );
}
mx[u] = max( mx[u] , size - sz[u] );
if( mx[u] < mx[root] ) root = u ;
}
int dis[maxn] , cnt = 0;
void dfs_dis( int u , int f , int w ){
dis[cnt++] = w;
erep( i , u ){
int v = e[i].v;
if( v == f || vis[v] ) continue;
dfs_dis( v , u , e[i].w + w );
}
}
int calc( int u , int w ){
int ret = 0; cnt = 0;
dfs_dis( u , 0 , w );
sort( dis, dis + cnt );
int l = 0 , r = cnt - 1;
while( l < r ){
while( dis[l] + dis[r] > K && l < r ) r--;
ret += (r - l) ; l++;
}
return ret;
}
int ans = 0;
void dfs( int u ){
ans += calc( u , 0 );
vis[u] = 1;
erep( i , u ){
int v = e[i].v;
if( !vis[v] ){
ans -= calc( v , e[i].w );
mx[0] = N , size = sz[u];
dfs_core( v , root = 0 );
dfs( root );
}
}
}
int main(){
while( scanf("%d %d" , &N , &K) != EOF){
if( !N && ! K ) break;
memset( head , 0xff , sizeof head );
memset( vis , 0 , sizeof vis );
_t = ans = 0;
int u , v , w;
rep( i , 1 , N - 1 ) {
scanf("%d %d %d" , &u , &v , &w );
addedge( u , v , w );
}
mx[0] = size = N;
ans = 0 ; dfs_core( 1 , root = 0 );
dfs( 1 );
cout << ans << endl;
}
return 0;
}