题目链接:http://poj.org/problem?id=1741
给定一颗树,以及k,一次询问,找出所有路径长度小于等于k的点对数。
前置:flag[]数组表示删除的点,d[]数组代表节点到根节点的距离,b[]数组代表节点属于哪一颗子树,a[]子树中所有的节点(并按照d从小到大排序),size1[]数组表示编号为i的子树的大小,cnt[]数组表示编号为i的子树出现的次数。
首先我们先选择一个点作为树根,将无根树转换为有根树,树根为p,那么节点x与节点y之间的路径就有两种情况:
如图是一颗树,第一次选取了1作为重心,将树划分成绿色的三部分,那么第一种路径就是2节点所在的块的点与3节点所在的块的点之间存在的点对数。
对于第二种路径,我们对2子树进行划分,类似第一次那种的划分,将子树分为了红色的四部分,其中2是该子树的重心,然后计算该子树中以2为树根的第一种路径的个数,接下来的划分也是同理。
于是通过这种方法将第二种路径转换为第一种路径求解即可。
即第一种路径的计算是由属于不同连通块的点之间得来,第二种路径的计算是将一个连通块通过划分转换为有根树求解。
- 经过树根的,那么x->y的路径可看做:x->p,p->y的路径和,那么这种情况下,可以预处理所有的点到树根的距离即可。
- 不经过树根的,那么这两个点一定在p的不包含p节点的一颗子树中,对于这种情况,我们递归深入的选取该子树的一个点作为子树树根,进行计算即可。
对于第一种路径,我们需要统计b[x]!=b[y]并且d[x]+d[y]<=k的点对。
介绍一种方法:指针扫描数组
将一颗树的所有节点放入a数组,并排序,再用两个指针变量l,r,分别从前往后,从后往前扫描a数组,并且易得知在l从前向后扫的过程中,满足d[a[l]]+d[a[r]]<=k的r的范围从后向前是单调递减的。
同时统计[l+1,r]范围内各个编号的子树出现的次数,用cnt[]数组,那么当d[a[l]]+d[a[r]]<=k时,点对数就是:r-l-cnt[ b[ a[l] ] ]。
整个点分治算法的过程就是:
- 选取树的重心作为根节点p
- 从p出发求出d,b数组
- 计算[l,r]内的点对
- 删除根节点p,对p的每棵子树递归执行1~4步。
如果我们任意的选取一个点作为树根的话,很可能会遇到最坏的情况就是一条链,递归深度太深,整个算法复杂度会退化到O(N*N*logN),若选取树的重心的话,树的深度最多到logN,那么复杂度就是O(N*logN*logN);
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<iostream>
using namespace std;
typedef long long ll;
const int maxn=1e4+7;
bool vis[maxn];
bool flag[maxn];
int size1[maxn];//以i为根的子树大小;
int res;
struct Edge{
int v,w,next;
}edge[maxn<<1];
int head[maxn],top;
int d[maxn],b[maxn];
void add(int u,int v,int w){
edge[top].v=v;
edge[top].w=w;
edge[top].next=head[u];
head[u]=top++;
}
void init(){
top=0;
res=0;
memset(head,-1,sizeof(head));
memset(flag,0,sizeof(flag));
memset(size1,0,sizeof(size1));
memset(b,0,sizeof(b));
memset(d,0,sizeof(d));
}
int rt;//树的重心;
int ans;
int n,k;
int a[maxn];//记录子树中所有的节点,并按照dis[a[i]]递增的方式排序;
int cnt[maxn];//记录编号为i的子树的大小;
bool cmp(int a,int b){
return d[a]<d[b];
}
//求解节点个数为zong个,的树的重心;
void get_rt(int u,int zong){
vis[u]=1;
size1[u]=1;
int maxx=0;
int v;
for(int i=head[u];i!=-1;i=edge[i].next){
v=edge[i].v;
if(vis[v]||flag[v]) continue;
get_rt(v,zong);
size1[u]+=size1[v];
maxx=max(maxx,size1[v]);
}
maxx=max(maxx,zong-size1[u]);
if(maxx<ans){
ans=maxx;
rt=u;
}
}
int t;//a数组的长度;
//求解b,d数组,同时记录a,cnt数组;
void dfs(int u,int fa){
++cnt[fa];
a[t++]=u;
b[u]=fa;
vis[u]=1;
for(int i=head[u];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(vis[v]||flag[v]) continue;
d[v]=d[u]+edge[i].w;
dfs(v,fa);
}
}
//划分树;
void work(int zong,int u){
t=0;
memset(cnt,0,sizeof(cnt));
memset(vis,0,sizeof(vis));
rt=u;
ans=zong;
get_rt(u,zong);
//rt为树的重心,rt自身为一颗子树,其余有直连边的点属于在各自的子树中;
a[t++]=rt;
b[rt]=rt;
++cnt[rt];
flag[rt]=1;
d[rt]=0;
memset(vis,0,sizeof(vis));
vis[rt]=1;
for(int i=head[rt];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(vis[v]||flag[v]) continue;
d[v]=0+edge[i].w;
dfs(v,v);
}
sort(a,a+t,cmp);
int l=0,r=t-1;
//计算[l+1,r]区间内满足d[a[l]]+d[a[j]]<=k&& b[a[j]]!=b[a[l]];的点数;
/*求解第一种路径*/
--cnt[b[a[0]]];
while(l<r){
while(l<r&&d[a[l]]+d[a[r]]>k){
--cnt[b[a[r]]];
--r;
}
if(l>=r) break;
res+=r-l-cnt[b[a[l]]];
++l;
--cnt[b[a[l]]];
}
/*划分子树,求第二种路径*/
int now=rt;
for(int i=head[now];i!=-1;i=edge[i].next){
int v=edge[i].v;
if(flag[v]) continue;
work(size1[v],v);
}
}
void solve(){
init();
int u,v,w;
for(int i=0;i<n-1;++i){
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
work(n,1);
printf("%d\n",res);
}
int main(){
while(scanf("%d%d",&n,&k)!=EOF){
if(n+k==0) break;
solve();
}
return 0;
}