点分治
Purpose
利用 \(nlogn\) 的时间复杂度巧妙地处理出很多树上路径统计问题。
Solve
我们如果选定一个点作为整棵树的根,那么树上的路径必定可以大致区分为以下两种:
- 经过根节点的
- 不经过根节点的
那么,我们只处理第一种路径,对于第二种,对于它经过的每个点,如果以这些点中的一个作为根节点,那么相对于这些点为根的时候,第二种路径肯定迟早会被处理。
而点分治的做法,则是选定一个点作为根节点,先统计所有经过此点的路径,然后分别递归下去处理它的每颗子树,在每颗子树中选出根节点,然后统计子树中经过子树根节点的路径,再递归处理子树根节点的子树。
那么直到处理的子树大小为 \(1\) 的时候,就到达了边界。
以上就是分治过程。
那么具体怎么做呢?我们分步骤考虑:
- 处理经过当前根节点的路径。
首先,如果一条路径经过当前根节点,那么它的两端必定在根节点的两个不同的子树中。
我们先一个个子树地统计每条路径到根节点的情况并记录下来,对于当前子树,我们统计完路径情况之后先不急着存入总情况内,由于总情况已经存入了前面访问过的所有子树,我们先将当前子树内所有情况与当前总情况一一配对,再将当前子树统计信息存入总情况内即可做到每颗子树内的配对无遗漏。
- 递归处理
对于根节点的每颗子树我们都要通过递归往下处理,但是每次的根节点怎么选择呢?难道乱选不成?
我们首先观察一下整个算法,我们发现,我们处理当前根节点路径信息的时间复杂度是 \(O(n)\) 的,同时,在递归处理的过程中,由于子树的大小加起来也恰好等于当前处理的树大小,所以也是 \(O(n)\) 的。那么不难得出,点分治算法的时间复杂度为 \(n×\)递归层数 。
而我们选择根节点,自然也要本着减少递归层数的原则来。那么如何使层数最小?每次选择当前树的重心即可,重心即为当前树内最大子树最小的节点。
接着我们如何寻找重心?我们只需要记录下每个节点的最大子树取 \(min\) 即可。最大子树的统计,只需要先在当前节点下所有子树的 \(size\) 取 \(max\) ,然后因为重心是在无根意义下的,所以我们同时还要反过来统计除了当前节点以外的节点数,即 \(Size_{whole}-Size_{now}\) ,与所有子树 \(size\) 取 \(max\) 即可。
由于每次根节点取的是重心,可以证明递归层数不超过 \(log\) 层。
我想到这里已经够清楚了,如果还不清楚,可以结合具体题目来看。
Example--聪聪可可
Description
给出 \(n\) 个节点的树,每次等概率选择两个树上节点 (可以相同) ,问有多大概率选择的节点之间的路径长度为 \(3\) 的倍数。输出分数形式即可,需约分。
\(n\le 2*10^4\)
Solution
点分治板子题,十分浅显易懂。
首先观察到题目和概率无关,答案就是输出 长度为 \(3\) 的倍数的路径数 / \(n^2\) ,输出的时候记得除个 \(gcd\) 就好。
那么怎么做呢?
简单!按照点分治的套路,先找出整颗子树的重心,然后以其为根建树,然后依次处理每颗子树。
\(sum[i]\) 记录之前所有子树中和根节点的距离 \(mod\ 3=i\) 的节点个数,\(dis[i]\) 则是记录当前子树中距离 \(mod\ 3=i\) 的节点个数,那么对于当前子树与之前的子树配对贡献情况自然如下:
\[ ans+=(dis[0]*sum[0]+dis[1]*sum[2]+dis[2]*sum[1])*2 \]
余数之和 \(mod\ 3=0\) 的配对想来很好懂,直接乘起来就行,但是由于等概率选择两次,先选这个再选那个反过来又是一种方案,所以需要 \(×2\) 。
还有就是注意这样统计并没有统计每个节点到自己的路径,所以最后还要 \(ans+=n\) 。这里不用 \(×2\) 了,因为确实只有一种情况。
Code
#include<iostream>
#include<cstdio>
using namespace std;
int n,head[20001],nx[40001],to[40001],w[40001];
int rt,sum,size[20001],maxs[20001],dis[20001],num[3],now[3];
int ans;
bool vis[20001];
void addroad(int u,int v,int W,int d)
{
to[d]=v,w[d]=W,nx[d]=head[u];
head[u]=d;
}
void getrt(int x,int fa)
{
size[x]=1,maxs[x]=0;
for(int i=head[x];i;i=nx[i])
if(to[i]!=fa&&!vis[to[i]])
{
getrt(to[i],x);
size[x]+=size[to[i]];
maxs[x]=max(maxs[x],size[to[i]]);
}
maxs[x]=max(maxs[x],sum-size[x]);
if(maxs[x]<maxs[rt])
rt=x;
}
void getdis(int x,int fa)
{
for(int i=head[x];i;i=nx[i])
if(to[i]!=fa&&!vis[to[i]])
{
dis[to[i]]=(dis[x]+w[i])%3;
now[dis[to[i]]]++;
getdis(to[i],x);
}
}
void work(int x)
{
for(int i=0;i<3;i++)num[i]=0;
num[0]=1;
for(int i=head[x];i;i=nx[i])
if(!vis[to[i]])
{
for(int i=0;i<3;i++)now[i]=0;
dis[to[i]]=w[i];
now[w[i]]++;
getdis(to[i],x);//计算路径信息
ans+=now[1]*num[2]*2+now[2]*num[1]*2+now[0]*num[0]*2;//计算贡献
for(int i=0;i<3;i++)num[i]+=now[i];
}
}
void solve(int x)
{
vis[x]=true;//先标记已经处理过的部分
work(x);//处理当前子树答案
for(int i=head[x];i;i=nx[i])
if(!vis[to[i]])
{
rt=0;
sum=size[x];
getrt(to[i],x);
solve(rt);//递归处理各大子树
}
}
int gcd(int a,int b)
{
return b==0?a:gcd(b,a%b);
}
int main()
{
cin>>n;
int u,v,W;
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&u,&v,&W);
W%=3;
addroad(u,v,W,i);
addroad(v,u,W,i+n);
}
maxs[rt]=sum=n;
getrt(1,0);//找重心
solve(rt);//开始处理
ans+=n;
int GCD=gcd(ans,n*n);
cout<<ans/GCD<<"/"<<n*n/GCD;//最后除个 gcd 就好
}