树的点分治,就是在对树分而治之。复杂度最坏情况下是nlogn
比较好的练手题:poj-1741 http://poj.org/problem?id=1741
题意:求树上距离小于等于K的点对有多少个
关于重心。
对于分治问题,我们都希望能够令分开的问题尽可能的平均,树的点分治也是这个道理。所以我们要先求出树的重心,即用这个重心,将树分成几个相对平均的部分。
int sz[maxn],maxx[maxn];///当前树大小,最大子树大小
void getG(int x,int fa,int sum){
maxx[x]=0; sz[x]=1;
for(int i=head[x];~i;i=e[i].next){
int t=e[i].to;ll w=e[i].val;
if(!vis[t]&&t!=fa){
getG(t,x,sum);
sz[x]+=sz[t],maxx[x]=max(sz[t],maxx[x]);
}
}
maxx[x]=max(maxx[x],sum-sz[x]);
if(maxx[G]>maxx[x]) G=x;
}
关于单个子树对其所有结点的处理。
我们需要把这个结点为根的所有的结点的距离进行计算并保存,用类似尺取的方法定下一个左指针和一个右指针,因为每个长度都指向一个结点,枚举左指针为边界,找第一个相加小于等于k的右节点,那么中间所有的结点相加也都将小于k,所以可以直接进行处理。
void dfs(int x,int fa){
for(int i=head[x];~i;i=e[i].next){
int t=e[i].to; ll w=e[i].val;
if(vis[t]||fa==t) continue;
len[++num]=dep[t]=dep[x]+w;
dfs(t,x);
}
}
int Cal(int x,int de){
int sum=0;
num=0;
len[++num]=dep[x]=de;
dfs(x,0);
sort(len+1,len+1+num);
for(int l=1,r=num;r>l;){
if(len[r]+len[l]<=k) sum+=r-l,l++;
else r--;
}
return sum;
}
关于对其直接连接的儿子进行的减操作。
新手看到这里的人可能会觉得有点奇怪,为什么在加了这个子树x之后还要减去其儿子u的值呢。因为如果有两个点同时在该儿子u的子树下,那么这两个结点a1,a2的距离差不会是这两个到x的和。所以要进行减去。
void div(int x){
vis[x]=1;
ans+=Cal(x,0);
for(int i=head[x];~i;i=e[i].next){
int t=e[i].to; ll w=e[i].val;
if(!vis[t]){
ans-=Cal(t,w);
G=0;
maxx[0]=sz[t];
getG(t,0,sz[t]);
div(G);
}
}
}
然后就可以直接凑起来啦。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int maxn=10005;
const int inf=1e9+7;
struct node{
int to,next;
ll val;
}e[maxn<<1];
int now,head[maxn],k,n,G,dep[maxn],dis[maxn],num;
int sz[maxn],maxx[maxn],vis[maxn],len[maxn];//当前树大小,最大子树大小
ll ans;
void add(int u,int v,ll w){
e[now].to=v,e[now].next=head[u];
e[now].val=w,head[u]=now++;
}
void getG(int x,int fa,int sum){
maxx[x]=0; sz[x]=1;
for(int i=head[x];~i;i=e[i].next){
int t=e[i].to;ll w=e[i].val;
if(!vis[t]&&t!=fa){
getG(t,x,sum);
sz[x]+=sz[t],maxx[x]=max(sz[t],maxx[x]);
}
}
maxx[x]=max(maxx[x],sum-sz[x]);
if(maxx[G]>maxx[x]) G=x;
}
void dfs(int x,int fa){
for(int i=head[x];~i;i=e[i].next){
int t=e[i].to; ll w=e[i].val;
if(vis[t]||fa==t) continue;
len[++num]=dep[t]=dep[x]+w;
dfs(t,x);
}
}
int Cal(int x,int de){
int sum=0;
num=0;
len[++num]=dep[x]=de;
dfs(x,0);
sort(len+1,len+1+num);
for(int l=1,r=num;r>l;){
if(len[r]+len[l]<=k) sum+=r-l,l++;
else r--;
}
return sum;
}
void div(int x){
vis[x]=1;
ans+=Cal(x,0);
for(int i=head[x];~i;i=e[i].next){
int t=e[i].to; ll w=e[i].val;
if(!vis[t]){
ans-=Cal(t,w);
G=0;
maxx[0]=sz[t];
getG(t,0,sz[t]);
div(G);
}
}
}
int main(){
int x,y; ll w;
while(~scanf("%d%d",&n,&k)){
if(n==0&&k==0) break;
memset(head,-1,sizeof(head)); now=0,ans=0;
memset(vis,0,sizeof(vis));
for(int i=1;i<n;i++){
scanf("%d%d%lld",&x,&y,&w);
add(x,y,w); add(y,x,w);
}
maxx[0]=n;
G=0; getG(1,0,n);
div(G);
printf("%lld\n",ans);
}
return 0;
}