点分治,是一种针对可带权树上简单路径统计问题的算法。本质上是一种带优化的暴力,带上一点容斥的感觉。
注意对于树上路径,并不要求这棵树有根,即我们只需要对无根树进行统计。接下来请把无根树这一关键点牢记于心。
下面给出一道例题 : poj的1741
对于不合法路径,我们该如何删除呢?
我们可以这样来处理:
不妨假设当前的父亲节点是图中的点A,其遍历到了儿子B,对于不合法路径,其实就相当于要筛除以B2为父节点来计算的路径不超过K的个数 ,这是因为若以B为父节点,原先不合法的合并在B中其实就相当于是合法的了,不妨可以画图理解一下,具体来说:
比如 ABE 与 ABD是不合法的合并,这是因为 D,E在同一颗子树当中,当我们以B为根节点进行筛除的时候,其实就相当于合法的了。因此可以利用这种性质筛除不合法路径。
【算法核心】
可以看到,计算一个节点的复杂度为O(nlogn)O(nlogn),但要保证总复杂度不超过一个量级却很难。
但是我们可以利用无根树的性质!
可以看到,把一个点的答案算完时,它的子节点所代表的子树就互不影响了!
就是说,这些子树彼此独立,可以完全当作一个新的子问题处理。
那么考虑如下算法:
- 对于这棵无根树,找到一个点,使得它在树的中心位置,满足如果以它为根,它的最大子树大小尽量小,这个点称为重心。
- 以这个点为根,计算它的答案。
- 把以这个点为根的树的所有子树单独作为一个子问题,回到步骤11递归处理。
这个算法的复杂度是多少呢?
先介绍一个定理:以树的重心为根的有根树,最大子树大小不超过n2n2。
假设超过了,大小为k>n2k>n2,那么其他子树大小之和等于n−k−1n−k−1。
那么把重心往这个子树方向移动,最大子树大小一定减小,因为n−k<n2<kn−k<n2<k。
那么进一步地,就证明了经过这个算法,递归的次数是O(logn)O(logn)级别。
这样,就进一步说明了算法总时间复杂度不超过O(nlog2n)O(nlog2n)。
#include<bits/stdc++.h>
using namespace std;
const int N = 1e4 + 10;
int n,k;
int h[N],ne[N * 2],edge[N * 2],e[N * 2],idx,ans,root;
int mx_son,mi_son,dis[N],stk[N],tt,sz_son[N],sum;
bool vis[N];
void add(int a,int b,int c)
{
e[idx] = b;
edge[idx] = c;
ne[idx] = h[a];
h[a] = idx++;
}
void init()
{
idx = 0;
memset(h,-1,sizeof h);
memset(vis,0,sizeof vis);
mi_son = 0x3f3f3f3f;
sum = n;
ans = 0;
}
void dfs_root(int x,int fa)
{
sz_son[x] = 1;
for(int i = h[x] ; i!= -1 ; i = ne[i])
{
if(!vis[e[i]] && e[i] != fa)
{
dfs_root(e[i],x);
sz_son[x] += sz_son[e[i]];
mx_son = max(mx_son,sz_son[e[i]]);
}
}
int w = max(sum - sz_son[x],mx_son);
if(w < mi_son)
{
mi_son = w;
root = x;
}
}
void get_dis(int rt,int fa,int len)
{
dis[rt] = len;
stk[++tt] = dis[rt];
for(int i = h[rt]; i != - 1 ; i = ne[i])
{
if(!vis[e[i]] && e[i] != fa)
{
dis[e[i]] = edge[i] + len;
get_dis(e[i],rt,len + edge[i]);
}
}
}
int calc(int rt,int fa,int len)
{
int ans1 = 0;
tt = 0;
memset(dis,0x3f,sizeof dis);
get_dis(rt,fa,len);
sort(stk + 1,stk + 1 + tt);
int l = 1 ,r = tt;
while(l <= r)
{
if(stk[l] + stk[r] <= k)
{
ans1 += r - l;
l++;
}
else r--;
}
return ans1;
}
void dfs(int rt)
{
int t;
vis[rt] = 1;
t = calc(rt,rt,0);
ans += t;
for(int i = h[rt]; i != -1 ; i = ne[i])
{
if(!vis[e[i]])
{
t = calc(e[i],rt,edge[i]);
ans -= t;
mi_son = 0x3f3f3f3f;
sum = sz_son[e[i]];
root = 0;
mx_son = 0;
dfs_root(e[i],rt);
dfs(root);
}
}
}
int main()
{
while(cin >> n >> k &&(n||k))
{
init();
for(int i = 1 ; i < n ; i++)
{
int a,b,c;
cin >> a >> b >> c;
a++;
b++;
add(a,b,c);
add(b,a,c);
}
dfs_root(1,1);
dfs(root);
cout << ans << endl;
}
}