树上点分治,用来处理这样一些问题,即给定一棵树,求出和树上所有点对之间距离相关的某个答案,如求树上两点之间距离等于k的点对个数(CF161D,板子题)。
如果了解lca,可能会想到利用lca求出所有点对间的距离,然后枚举点对统计答案,复杂度大概是O(n2logn)。不过看一眼题,n最大为5e5,即使时间给了三秒,也应该会TLE。
点分治的思想大概 是对于给定的树,通过选取其中的特殊点(后面讲),将其拆为不同子树,并对子树进行取点、拆分操作,然后分别对每棵子树计算答案,最后统计结果。
为什么这样做是对的呢
简单解释一下毕竟这算法我也不是很熟 ,考虑一棵树上满足条件的路径可能由以下情况组成:1,在树的某棵子树上取某些边;2,跨过树根root,在两边的子树中取边组成。
不过再细想一下,对于情况1的那棵子树,再把它看作几棵子树,那么成路径的方式其实和情况2一样,即其路径是两个“小子树”上选取边跨过这棵“大子树”的根组成的。
这样基本思路可能 就出来了,分出子树后考虑子树就行了。
先解决找子树的根的问题
找根不能随便找,因为要显得算法厉害 节省遍历的时间,所以要找的点就应该满足“以这个点作为树根,其最大的子树要尽量小”这个条件。这个点就叫重心。
对着代码讲:
const int maxn=50000+5;
const int inf=1e9;
int n,k,p,cntn;//cntn可以理解为当前子树要考虑的点的总个数,因为题中要求统计点对个数,没问具体是哪些点对,这样写着方便一点
int h[maxn],w[maxn*2],v[maxn*2],nxt[maxn*2];//邻接表大法好
ll ans;//最终答案
int mxsont[maxn],sz[maxn],focus,sn;//mxsont[i]记录以i为根的最大子树的大小,sz[i]表示i的子树大小,focus记录当前重心,sn是当前考虑的“大子树”的大小
bool deleted[maxn];//计算完一个点的子树,就把这个点记录为删去
ll dis[maxn];//dis[i]记录i到当前子树树根的距离
把要定义的交代清楚,下面就是找重心的过程:
void getfocus(int x,int fa)
{
mxsont[x]=0;
sz[x]=1;//初始化
for(int i=h[x];i;i=nxt[i])
{
int to=v[i];
if(to==fa||deleted[to])continue;//判重,防止搜索搜回去
getfocus(to,x);
sz[x]+=sz[to];//加上子树的大小
if(sz[to] > mxsont[x])mxsont[x]=sz[to];
}
if((sn-sz[x]) > mxsont[x])mxsont[x]=sn-sz[x];//???
if(mxsont[x] < mxsont[focus])focus = x;
}
解释下问号那一句,画个图:
这里假如我们递归到红色节点x,计算完其子树大小,后来发现它的“头上”还有一部分,如果要把x当作树根,那么上面那一部分就应该是它的子树,所以也需要比较一下,更新x的最大子树大小。
好了到这里我们就退赛了 找到了重心。
之后,在每棵子树上,计算节点到树根的距离
dfs大法好
void getdis(int x,int fa,int d)
{
for(int i=h[x]; i; i=nxt[i])
{
int to=v[i];
if(to==fa||deleted[to])continue;
dis[++cntn]=d+w[i];
getdis(to,x,dis[cntn]);
}
}
别忘了cntn是当前子树节点个数
现在开始对每棵子树统计答案
int lookforl(int l,int check)//找可能小于check的最小数
{
int findl=0;
int r=cntn;
while(l<=r)
{
int mid=(l+r)>>1;
if(dis[mid]<check)
{
l=mid+1;
}
else
{
findl=mid;
r=mid-1;
}
}
return findl;
}
int lookforr(int l,int check)//找可能小于check的最大数
{
int findr=0;
int r=cntn;
while(l<=r)
{
int mid=(l+r)>>1;
if(dis[mid]<=check)
{
findr=mid;
l=mid+1;
}
else
{
r=mid-1;
}
}
return findr;
}
int calansx(int x,int fad)//统计x为根的子树上的答案,fad可以理解为x到x的父亲的距离,后面会讲
{
cntn=1;
dis[1]=fad;
getdis(x,0,fad);
sort(dis+1,dis+cntn+1);//二分查找别忘了排序
int l=1;
int ansx=0;
while(l<cntn&&(dis[l]+dis[cntn]<k))l++;
while(l<cntn&&(k-dis[l]>=dis[l]))
{
int l1=lookforl(l+1,k-dis[l]);
int r1=lookforr(l+1,k-dis[l]);
if(r1>=l1)
{
ansx+=r1-l1+1;
}
++l;
}
return ansx;
}
这里用到了一个二分,初次看比较难以理解,不过目的是很明确的,毕竟你要是枚举子树上的所有点统计答案也是可以试试的 还是会TLE
重点来了,分治函数
void dfsansn(int x)
{
deleted[x]=1;
ans+=calansx(x,0);//????????
for(int i=h[x];i;i=nxt[i])
{
int to=v[i];
if(deleted[to])continue;
ans-=calansx(to,w[i]);//????????
sn=sz[to];
focus=0;
getfocus(to,x);
dfsansn(focus);
}
}
这里就要讲一讲问号的两句了,看图
在对x统计答案的时候,对x下面的点a,b,计算dis<a,x>+dis<x,b>来统计答案是不合法的,因为我们重复计算了x.to到x之间的距离,所以当我们计算完x为根时子树答案后,减去以to为根的子树的非法路径统计出的答案,而这些路径之所以非法是因为计算两点间距离时是加了to和x之间的距离的,学长说这是容斥原理但是他没讲清楚 我没学过(还是菜啊)
到这里,这题就已经A了
int main()
{
n=getnum();
k=getnum();
for(int i=1;i<n;i++)
{
int a=getnum(),b=getnum();
add(a,b,1);
add(b,a,1);
}
sn=n;
mxsont[0]=inf;
getfocus(1,0);
dfsansn(focus);
cout<<ans<<"\n";
return 0;
}
当然还差这些神秘代码。
终于码完了,毕竟是初学,某些地方的解释可能不是很严谨,欢迎指正。
另附上其他题目:洛谷P3806
POJ1741