作用
在树上实现,通过分治的方法,使原来需要复杂度为n^2的做法变为n*logn.
实现方法
首先找到树的重心,然后以重心为根,考虑需要用到根的部分,之后去掉根节点,将原来子树分成多个小于原来大小的一半的多棵子树,再用上述方法反复处理,不难证明复杂度为n*logn.
因此点分的核心代码就是统计每个点的子树大小和找到根节点.
void getsz(int now,int last)
{
int p,q;
size[now]=1;
for(p=first[now]; p!=-1; p=bn[p].next)
{
if(vis[bn[p].to]||bn[p].to==last) continue;
getsz(bn[p].to,now);
size[now]+=size[bn[p].to];
}
}
int getrt(int now,int last,int tot)
{
int p,q;
for(p=first[now]; p!=-1; p=bn[p].next)
{
if(vis[bn[p].to]||bn[p].to==last||size[bn[p].to]*2<tot) continue;
return getrt(bn[p].to,now,tot);
}
return now;
}
例题poj 1741 Tree
给出一棵有边权树,问最短距离小于等于k的点对有几对.
做法
用点分的方法,每次考虑通过根节点的路径长度,用递归求出每个点到根节点的距离,排序之后可以用O(n)的复杂度求出有几个点对符合,再用同样的方法减掉在同一棵子树中的点
对(在统计那一棵子树时会重复计算).
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 10010
#define INF 0x3f3f3f3f
using namespace std;
int n,k,first[N],bb,size[N],ans,deep[N],de[N],dd;
bool vis[N];
struct bn
{
int to,next,quan;
} bn[N<<1];
inline void add(int u,int v,int w)
{
bb++;
bn[bb].to=v;
bn[bb].next=first[u];
bn[bb].quan=w;
first[u]=bb;
}
void getsz(int now,int last)
{
int p,q;
size[now]=1;
for(p=first[now]; p!=-1; p=bn[p].next)
{
if(vis[bn[p].to]||bn[p].to==last) continue;
getsz(bn[p].to,now);
size[now]+=size[bn[p].to];
}
}
int getrt(int now,int last,int tot)
{
int p,q;
for(p=first[now]; p!=-1; p=bn[p].next)
{
if(vis[bn[p].to]||bn[p].to==last||size[bn[p].to]*2<tot) continue;
return getrt(bn[p].to,now,tot);
}
return now;
}
void gd(int now,int last)
{
int p,q;
de[++dd]=deep[now];
for(p=first[now]; p!=-1; p=bn[p].next)
{
if(bn[p].to==last||vis[bn[p].to]) continue;
deep[bn[p].to]=deep[now]+bn[p].quan;
gd(bn[p].to,now);
}
}
inline int js(int now,int cz)
{
int i,j,res=0;
dd=0;
deep[now]=cz;
gd(now,-1);
sort(de+1,de+dd+1);
for(i=1,j=dd; i<j;)
{
de[i]+de[j]<=k?res+=j-i,i++:j--;
}
return res;
}
void work(int now)
{
int p,q;
ans+=js(now,0);
vis[now]=1;
for(p=first[now]; p!=-1; p=bn[p].next)
{
if(vis[bn[p].to]) continue;
ans-=js(bn[p].to,bn[p].quan);
getsz(bn[p].to,-1);
work(getrt(bn[p].to,-1,size[bn[p].to]));
}
}
int main()
{
int i,j,p,q,o;
for(;;)
{
scanf("%d%d",&n,&k);
if(!n&&!k) return 0;
bb=ans=0;
memset(first,-1,sizeof(first));
memset(vis,0,sizeof(vis));
for(i=1; i<n; i++)
{
scanf("%d%d%d",&p,&q,&o);
add(p,q,o),add(q,p,o);
}
getsz(1,-1);
work(getrt(1,-1,n));
printf("%d\n",ans);
}
}