0x45
点分治
到目前为止,我们用数据结构处理的大多是序列上的问题。这些问题的形式一般是给定序列中的两个位置 l l l和 r r r,在区间 [ l , r ] [l,r] [l,r]上执行查询或修改指令。如果给定一棵树,以及树上两个节点 x x x和 y y y,那么与“序列上的区间”相对应的就是“树上两点之间的路径”。我们先不考虑对路径进行修改的的操作。本节中介绍的点分治就是在一棵树上,对具有某些限定条件的路径静态进行统计的算法。
点分治是一种解决树上统计问题的常用方法,本质思想就是选择一点(重心)作为分治中心,将原问题划分为几个相同的子树上的问题,进行递归解决。
给一颗有 N N N个点的树,每条边都有一个权值。树上两个节点 x x x和 y y y之间的路径长度就是路径上各条边的权值之和。求长度不超过 K K K的路径有多少条。
本题中的边是无向的,即这棵树是一个由 N N N个点、 N − 1 N-1 N−1条边构成的无向连通图。我们把这种树称为“无根树”(所需维护的信息与根节点是谁无关),也就是说可以任意指定一个节点为根节点,而不影响问题的答案。
若指定节点 p p p为根,则对 p p p而言,树上的路径可以分为两类:
1.经过根节点 p p p(包含一端为根节点 p p p)。
2.包含于 p p p的某一棵子树中(不经过根节点)。
根据分治的思想,对于第2类路径,显然可以把 p p p的每棵子树作为子问题,递归进行处理。
而对于第1类路径,可以从根节点
p
p
p分成“
x
∼
p
x\sim p
x∼p”与“
p
∼
y
p\sim y
p∼y”两段。回顾在0x21
节所学到的知识,我们可以从
p
p
p出发对整棵树进行DFS
,求出数组
d
d
d,其中
d
[
x
]
d[x]
d[x]表示点
x
x
x到根节点
p
p
p的距离。同时还可以求出数组
b
b
b,其中
b
[
x
]
b[x]
b[x]表示点
x
x
x属于根节点
p
p
p的哪一棵子树,特别的,令
b
[
p
]
=
p
b[p]=p
b[p]=p。
此时满足题目要求的第1类路径满足以下两个条件的点对 ( x , y ) (x,y) (x,y)的个数:
1. b [ x ] ≠ b [ y ] b[x]\neq b[y] b[x]=b[y]。
2. d [ x ] + d [ y ] ≤ K d[x]+d[y]\leq K d[x]+d[y]≤K。如下图所示。
定义 C a l ( p ) Cal(p) Cal(p)表示在以 p p p为根的树中统计上述点对的个数(第1类路径的条数)。 C a l ( p ) Cal(p) Cal(p)有两种常见的实现方式。针对不同的题目,二者各有优劣。
方法一:树上直接统计
设 p p p的子树为 s 1 , s 2 , . . . , s m s_1,s_2,...,s_m s1,s2,...,sm。
对于 s i s_i si中每个节点 x x x,把在子树 s 1 , s 2 , . . . , s i − 1 s_1,s_2,...,s_{i-1} s1,s2,...,si−1中满足 d [ x ] + d [ y ] ≤ K d[x]+d[y]\leq K d[x]+d[y]≤K的节点 y y y的个数累加到答案中即可。
具体来说,可以建立一个树状数组,依次处理每棵子树 s i s_i si。
1.对于 s i s_i si中的每个节点 x x x,查询前缀和 a s k ( K − d [ x ] ) ask(K-d[x]) ask(K−d[x]),即为所求的 y y y的个数。
2.对于 s i s_i si中的每个节点 x x x,执行 a d d ( d [ x ] , 1 ) add(d[x],1) add(d[x],1),表示与 p p p距离为 d [ x ] d[x] d[x]的节点增加了1个。
按子树一棵棵进行处理保证了 b [ x ] ≠ b [ y ] b[x]\neq b[y] b[x]=b[y],查询前缀和保证了 d [ x ] + d [ y ] ≤ K d[x]+d[y]\leq K d[x]+d[y]≤K。
需要注意的是,树状数组的范围与路径长度有关,这个范围远比 N N N要大。而本题中不易进行离散化。一种解决方案是用平衡树代替树状数组,以保证 O ( N l o g N ) O(NlogN) O(NlogN)的复杂度,但代码复杂度显著增加。所以本题更适用下一种方法。
方法二:指针扫描数组
把树中每个点放进一个数组 a a a,并把数组 a a a按照节点的 d d d值排序。
使用两个指针 L , R L,R L,R分别从前、后开始扫描 a a a数组。
容易发现,在指针 L L L从左往右扫描的过程中,恰好使得 d [ a [ L ] ] + d [ a [ R ] ] ≤ K d[a[L]]+d[a[R]]\leq K d[a[L]]+d[a[R]]≤K的指针 R R R的范围是从右往左单调递减的。
另外,我们用数组 c n t [ s ] cnt[s] cnt[s]维护在 L + 1 L+1 L+1与 R R R之间满足 b [ a [ i ] ] = s b[a[i]]=s b[a[i]]=s的位置 i i i的个数。
于是,当路径的一端 x x x等于 a [ L ] a[L] a[L]时,满足题目要求的路径另一端 y y y的个数就是 R − L − c n t [ b [ a [ L ] ] ] R-L-cnt[b[a[L]]] R−L−cnt[b[a[L]]]。
总而言之,整个点分治算法的过程就是:
1.任选一个根节点 p p p(后面我们将说明, p p p应该取树的重心)。
2.从
p
p
p出发进行一次DFS
,求出
d
d
d数组和
b
b
b数组。
3.执行 C a l ( p ) Cal(p) Cal(p)。
4.删除根节点 p p p,对 p p p的每棵子树(看作无根树)递归执行1~4步。
在点分治过程中,每一层的所有递归过程合计对每个点处理1次。因此,若递归最深处到达第 T T T层,整个算法的时间复杂度为 O ( T N l o g N ) O(TNlogN) O(TNlogN)。
如果问题中的树是一条链,最坏情况下每次都以链的一端为根,那么点分治将需要递归
N
N
N层,时间复杂度退化到
O
(
N
2
l
o
g
N
)
O(N^2logN)
O(N2logN)。为了避免这种情况,我们每次选择树的重心(曾在0x21
节提及)作为根节点
p
p
p。对于树上的每一个点,计算其所有子树中最大的子树节点数,这个值最小的点就是这棵树的重心。而不难证明树的重心具有以下性质:以树的重心为根时,所有子树的大小都不超过整棵树大小的一半。
点分治就至多递归 O ( l o g N ) O(logN) O(logN)层,算法的时间复杂度为 O ( N l o g 2 N ) O(Nlog^2N) O(Nlog2N)。如下图所示。
#include <bits/stdc++.h>
using namespace std;
const int SIZE=1e4+5;
int N,K,tot,w,sum,cnt,ans;
int ver[SIZE*2],edge[SIZE*2],nex[SIZE*2],head[SIZE];
int max_part[SIZE],siz[SIZE],dis[SIZE],root[SIZE],rec[SIZE],point[SIZE];
bool del[SIZE];
inline int read()
{
int x=0,f=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') f=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9')
x=x*10+ch-'0',ch=getchar();
return x*f;
}
inline void add(int x,int y,int z)
{
ver[++tot]=y,edge[tot]=z;
nex[tot]=head[x],head[x]=tot;
}
void dfs_w(int x,int fa)
{
siz[x]=1,max_part[x]=0;
for(int i=head[x];i;i=nex[i])
{
int y=ver[i];
if(y==fa||del[y]) continue;
dfs_w(y,x);
siz[x]+=siz[y];
max_part[x]=max(max_part[x],siz[y]);
}
max_part[x]=max(max_part[x],sum-siz[x]);
if(max_part[x]<max_part[w])
w=x;
}
void dfs(int x,int fa)
{
point[++cnt]=x,siz[x]=1;
for(int i=head[x];i;i=nex[i])
{
int y=ver[i],z=edge[i];
if(y==fa||del[y]) continue;
if(x==w) root[y]=y;
else root[y]=root[x];
rec[root[y]]++;
dis[y]=dis[x]+z;
dfs(y,x);
siz[x]+=siz[y];
}
}
void solve(int x,int fa)
{
dfs_w(x,fa);
dis[w]=0;
root[w]=w;
rec[w]=1;
for(int i=head[w];i;i=nex[i])
{
int y=ver[i];
rec[y]=0;
}
cnt=0;
dfs(w,0);
sort(point+1,point+cnt+1,[](int x,int y){
return dis[x]<dis[y];
});
int L=1,R=cnt;
rec[root[point[L]]]--;
while(L<R)
{
if(dis[point[L]]+dis[point[R]]>K)
{
rec[root[point[R]]]--;
R--;
}
else
{
ans+=R-L-rec[root[point[L]]];
L++;
rec[root[point[L]]]--;
}
}
del[w]=true;
for(int i=head[w];i;i=nex[i])
{
int y=ver[i];
if(y==fa||del[y]) continue;
sum=siz[y],w=0,max_part[0]=0x3f3f3f3f;
solve(y,w);
}
}
int main()
{
N=read();K=read();
while(N||K)
{
tot=0;
for(int i=1;i<=N;++i) head[i]=0,del[i]=false;
int x,y,z;
for(int i=1;i<N;++i)
{
x=read();y=read();z=read();
x++,y++;
add(x,y,z);
add(y,x,z);
}
ans=0;
sum=N,w=0,max_part[0]=0x3f3f3f3f;
solve(1,0);
printf("%d\n",ans);
N=read();K=read();
}
return 0;
}