点分治
树的点分治,是在树中找一个点,把它砍掉后,树就变成了一个森林,然后分别处理这个森林中的每一棵树,统计答案。显然,如果你砍掉叶子结点,这个分治就没有意义了。所以我们要找一个点把树尽可能地平均分,这个点叫树的重心。所谓平均分,就是:这个点的最大子树(包括父亲那边的一堆)的大小(
MaxS
M
a
x
S
)要最小,例如:
此时点
1
1
的就是
8
8
,再如:
此时点的
MaxS
M
a
x
S
就是
9
9
。
树的重心就是一个最小的点,一棵树可能有多个重心,而这个树的一个重心是 2 2 ,它的是 7 7 。
求树的重心
思路
按照递归的方式,处理出结点的每个子树的大小 Size[v] S i z e [ v ] ,注意不要往父亲方向上走,即使重心需要知道父亲方向的 Size S i z e (如第二个图的 Size=9 S i z e = 9 ),如果往父亲方向走就无限递归了。如何知道父亲方向的 Size S i z e 呢?很简单: N−Size[u] N − S i z e [ u ] ,即总的减掉 u u 为根子树的大小。对这些取最大值就是结点 u u 的,再对每个结点的 MaxS M a x S 取最小值,找到的即为重心。
模板题
模板题大意
给你一棵无根树, n n 个结点条边,第一行输出两个整数,分别代表树的重心的最大子树的大小(子树大小的定义前面已经说过)和重心的个数,第二行按升序输出所有重心。
代码
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
int read(){
int x=0;char c=getchar();
while(c<'0'||c>'9') c=getchar();
while(c>='0'&&c<='9') x=x*10+c-'0',c=getchar();
return x;
}
#define MAXN 16000
#define INF 0x3f3f3f3f
int N,Min=INF;
int Size[MAXN+5];
vector<int> Ans;
vector<int> G[MAXN+5];
void dfs(int rt,int f){
int tmp=0;//tmp就是这个点的MaxS
Size[rt]=1;//一开始之一u这一个结点
for(int i=0;i<int(G[rt].size());i++){
int v=G[rt][i];
if(v!=f){//判断不走回头路
dfs(v,rt);
Size[rt]+=Size[v];//加上以v为根的子树的大小
tmp=max(tmp,Size[v]);//取最大值
}
}
tmp=max(tmp,N-Size[rt]);//找到父亲那边的大小
if(tmp<Min){
Min=tmp;
Ans.clear();
}
if(tmp==Min)//需要找所有重心
Ans.push_back(rt);//统计答案
}
int main(){
N=read();
for(int i=1;i<N;i++){
int u=read(),v=read();
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,-1);
sort(Ans.begin(),Ans.end());//注意sort
printf("%d %d\n",Min,Ans.size());
for(int i=0;i<int(Ans.size());i++)
printf("%d ",Ans[i]);
}
典型例题
题目
POJ 1741Tree
题目大意
给你一棵边带权的 n n 个结点的树,问你这个树上的长度不大于的路径有多少条。
思路
既然是点分治,就要找到重心,找到重心后统计通过重心的路径中满足题意的有多少条,然后递归子树重复上诉操作,注意判重,即在找通过重心 v v 的路径时不能通过已经访问过的重心。
算法流程:
- 找到以 u u 为根的子树的重心,标记 C C 。
- 统计通过的路径中有多少条长度不大于 k k ,这些路径上的点都不能是被标记了的点。
- 递归的子节点 v v ,不能被标记过,以 v v 为根,重复上诉操作。
难点是第步,如何统计这些路径。可以把
u→v
u
→
v
经过
C
C
看成加上
v→C
v
→
C
。
设
dist[i]
d
i
s
t
[
i
]
表示深搜是访问到的第
i
i
个点(注意不是原始数据中的点的编号,详见代码)到的路径长度(用深搜找到),则需要找
dist[i]+dist[j]≤k
d
i
s
t
[
i
]
+
d
i
s
t
[
j
]
≤
k
的数量即可。于是很多大佬开始秀自己的平衡树了……蒟蒻只能用排序+滑窗来做。将
dist
d
i
s
t
排序后:
- 如果
dist[i]+dist[j]≤k
d
i
s
t
[
i
]
+
d
i
s
t
[
j
]
≤
k
,那么
i→i+1
i
→
i
+
1
、
i→i+2
i
→
i
+
2
……
i→j
i
→
j
这
j−i
j
−
i
条路径都符合条件,所以
Ans+=j-i
,同时,如果 j j 是倒着枚举的,那么以开始的路径不可能有更多了,所以i++
。 - 否则:就说明这个区间取大了,
j--
。 - 换句话说,就是枚举 i i 为起点的路径有多少条,是用来辅助的。
然而有个问题:
这个图中,假设边权都是
1
1
、,当以
2
2
为重心找的时候:到
2
2
的长度是,
11
11
到
2
2
的长度是。按照之前的说法,
9
9
到的长度是
4
4
,但是显然到
11
11
的简单
路径长度是
2
2
,原因是走了两遍这条边。
所以,一旦有两个点在 u u 的同一个儿子的子树中,这两个点构成的路径就是不合法的,因为你强行让他们的路径经过 u u ,而实际上他们的路径只需要经过。换句话说, u→v u → v 走了两遍。
我们只需要把这种情况的个数减去就可以了,详情见代码(这个我也不好解释)。
代码
#include<set>
#include<cstdio>
#include<vector>
#include<cstring>
#include<algorithm>
using namespace std;
int read(){
int x=0,f=1;char c=getchar();
while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
return x*f;
}
#define MAXN 10000
#define INF 0x3f3f3f3f
struct Edge{
int v,w;
Edge(){}
Edge(int a,int b){
v=a,w=b;
}
};
int N,K,Ans;
vector<Edge> G[MAXN+5];
bool vis[MAXN+5];
int Size[MAXN+5],MaxS[MAXN+5],Center;
//Center是重心
//其实不用开MaxS这个数组
void Find(int u,int fa,int All){//找重心
Size[u]=1,MaxS[u]=0;
for(int i=0;i<int(G[u].size());i++){
int v=G[u][i].v;
if(v!=fa&&!vis[v]){
Find(v,u,All);
Size[u]+=Size[v];
MaxS[u]=max(MaxS[u],Size[v]);
}
}
MaxS[u]=max(MaxS[u],All-Size[u]);
if(MaxS[u]<MaxS[Center])
Center=u;
}
int dist[MAXN+5],cnt;
void dfs(int u,int fa,int dep){
dist[++cnt]=dep;//记录路径长度
for(int i=0;i<int(G[u].size());i++){
int v=G[u][i].v,w=G[u][i].w;
if(v!=fa&&!vis[v])
dfs(v,u,dep+w);
}
}
int Cal(int u,int backW){
//backW是为了保持儿子更新出来的dist跟自己更新出来的dist一样
cnt=0;
dfs(u,-1,backW);
sort(dist+1,dist+cnt+1);//注意一定要排序
int Left=1,Right=cnt,ret=0;
while(Left<Right){//滑窗
if(dist[Left]+dist[Right]<=K)
ret+=Right-Left,Left++;
else
Right--;
}
return ret;
}
void Solve(int u){
vis[u]=1;
Ans+=Cal(u,0);//先算u
for(int i=0;i<int(G[u].size());i++){
int v=G[u][i].v,w=G[u][i].w;
if(!vis[v]){
Ans-=Cal(v,w);
//减掉经过v的,读者可以自己画图分析一下,以下是我的理解
//由于v有一条backW,所以Cal(v,w)里面统计出来的每一条路径都加了2*backW
//即每条路径都重复走了边u->v
//把这种情况减去
Center=0;//找子树重心,这里与上述算法流程稍有不同,本质是一样的
Find(v,-1,Size[v]);
Solve(Center);
}
}
}
int main(){
while(1){
N=read(),K=read();
if(!N&&!K)
return 0;
for(int i=1;i<N;i++){
int u=read(),v=read(),w=read();
G[u].push_back(Edge(v,w));
G[v].push_back(Edge(u,w));
}
//for(int i=0;i<int(G[1].size());i++)
// printf("%d ",G[1][i].v);
Size[1]=N;
MaxS[0]=INF;
Center=0;
Find(1,-1,N);
Solve(Center);
printf("%d\n",Ans);
Ans=0;
for(int i=1;i<=N;i++)
G[i].clear();
memset(vis,0,sizeof vis);
}
}