点分治用于解决与树上路径有关的统计问题。
点分治的思想类似一般的分治:对于一棵树,一方面处理跨越不同子树的路径,即经过根的路径(相当于分治后合并时考虑跨越不同区间的影响),另一方面递归到各个子树中处理子树内的路径(相当于分治中考虑区间内的问题)。
而实际处理中,我们希望递归层数尽量小。因此,我们在处理一棵无根树时,选取的根要满足:其最大子树的大小尽量小。这个根称为重心。
总结起来,点分治流程如下:
1.找重心。
怎么找?
dfs一遍,计算每个点为根时其各个子树的大小,比较其最大子树的大小与之前找到最优点的最大子树的大小。如果更小,则更新最优点。
代码:
void getroot(int u,int fa){
size[u]=1,F[u]=0;//size[u]为u为根的树的大小,F[u]为u为根的树的最大子树大小
for(int i=head[u];i!=-1;i=G[i].nxt){
int v=G[i].to;
if(v==fa||vis[v]) continue;
getroot(v,u);
size[u]+=size[v];
F[u]=max(F[u],size[v]);
}
F[u]=max(F[u],sum-size[u]);//sum为当前整棵树的大小,这里是考虑以u的父亲为根的子树
if(F[u]<F[root]) root=u;
}
2.处理经过根的路径。要注意的是,我们处理时会遇到路径两端在同一子树内的路径,这种路径统计时需要删去(容斥原理)。
3.标记根节点(相当于删去)。
4.递归处理各个子树。
以本题为例:
1.找重心。
2.dfs求出当前树中各个结点到根的距离,然后将所有距离排序,统计符合条件的点对数,具体方法如下:
设由小到大排好序的所有距离为Q[i],两个指针l和r,初始时l+1为在开头,r为末尾。
如果
Q[l+1]+Q[r]≤k
Q
[
l
+
1
]
+
Q
[
r
]
≤
k
,那么对于一端为结点l+1的路径,另一端为编号小于等于r且不等于l+1的结点都符合要求,而其中小于l+1的部分都已经统计过了,所以
ans+=r−l−1
a
n
s
+
=
r
−
l
−
1
,之后
l++
l
+
+
;
否则,
r−−
r
−
−
,直到满足上述要求。
但这样得到的还不是正确答案。我们将所有结点放在一起统计,就可能存在路径的两端都来自同一子树的情况,如下图:
在上述算法中,如果路径3——2——1——2——4的长度小于等于k,那么它就会被我们统计到。显然,这种路径是不符合要求的。如何将其删去?
我们可以在以2号点为根的子树中进行一次上述统计,其中计算所有点到根2号点的距离时都加上边1——2的长度。这样我们计算出的距离实际是各个点到1号点的距离,因此统计出的路径也是前面提到的需要删去的路径。对各个子树都进行一遍这样的操作,从答案中减去即可。
3.标记根节点(相当于删去)。
4.递归处理各个子树。
讨论时间复杂度:
可以发现,递归层数最多为 logn l o g n ,因为每次找到重心后,最大子树一定不超过整棵树大小的一半。每层递归中dfs一遍所有子树为 O(n) O ( n ) ,排序为 O(nlogn) O ( n l o g n ) 。因此总时间复杂度 O(nlog2n) O ( n l o g 2 n ) 。
代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int maxn=40010;
int n,k,ans;
int root,sum,size[maxn],F[maxn],dis[maxn];
bool vis[maxn];
int Q[maxn],l,r;
int head[maxn],Ecnt;
struct edge{
int to,nxt,w;
}G[maxn*2];
inline void addE(int u,int v,int w){
G[Ecnt]=(edge){v,head[u],w};
head[u]=Ecnt++;
G[Ecnt]=(edge){u,head[v],w};
head[v]=Ecnt++;
}
void getroot(int u,int fa){
size[u]=1,F[u]=0;
for(int i=head[u];i!=-1;i=G[i].nxt){
int v=G[i].to;
if(v==fa||vis[v]) continue;
getroot(v,u);
size[u]+=size[v];
F[u]=max(F[u],size[v]);
}
F[u]=max(F[u],sum-size[u]);
if(F[u]<F[root]) root=u;
}
void getdis(int u,int fa){
Q[++r]=dis[u];
for(int i=head[u];i!=-1;i=G[i].nxt){
int v=G[i].to;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+G[i].w;
getdis(v,u);
}
}
int calc(int u,int d){
int ret=0;
l=r=0;
dis[u]=d;
getdis(u,0);
sort(Q+1,Q+r+1);
while(l<r){
if(Q[l+1]+Q[r]<=k) ret+=r-l-1,++l;
else --r;
}
return ret;
}
void solve(int u){
ans+=calc(u,0);
vis[u]=1;
for(int i=head[u];i!=-1;i=G[i].nxt){
int v=G[i].to;
if(vis[v]) continue;
ans-=calc(v,G[i].w);
sum=size[v],root=0;
getroot(v,0);
solve(root);
}
}
int main(){
memset(head,-1,sizeof(head));
scanf("%d",&n);
for(int i=1;i<n;i++){
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
addE(u,v,w);
}
scanf("%d",&k);
sum=F[0]=n,root=0;
getroot(1,0);
solve(root);
printf("%d\n",ans);
return 0;
}