类似的还有POJ 1987
题意:给出一棵树,问树上点对<A,B>所代表的简单路径长度<=K的不同点对数。
解法:树分治经典题了,求重心分治后,dfs求当前子树所有点到重心的距离,将求得的距离排序,用指针扫描法O(N)统计路径长度<=K的点对数量,累加即是答案,注意需要减去同一颗子树上的点对。
代码:
#include<stdio.h>
#include<string.h>
#include<algorithm>
#include<math.h>
#include<iostream>
#include<stdlib.h>
#include<set>
#include<map>
#include<queue>
#include<vector>
#include<bitset>
#pragma comment(linker, "/STACK:1024000000,1024000000")
template <class T>
bool scanff(T &ret){ //Faster Input
char c; int sgn; T bit=0.1;
if(c=getchar(),c==EOF) return 0;
while(c!='-'&&c!='.'&&(c<'0'||c>'9')) c=getchar();
sgn=(c=='-')?-1:1;
ret=(c=='-')?0:(c-'0');
while(c=getchar(),c>='0'&&c<='9') ret=ret*10+(c-'0');
if(c==' '||c=='\n'){ ret*=sgn; return 1; }
while(c=getchar(),c>='0'&&c<='9') ret+=(c-'0')*bit,bit/=10;
ret*=sgn;
return 1;
}
#define inf 1073741823
#define llinf 4611686018427387903LL
#define PI acos(-1.0)
#define lth (th<<1)
#define rth (th<<1|1)
#define rep(i,a,b) for(int i=int(a);i<=int(b);i++)
#define drep(i,a,b) for(int i=int(a);i>=int(b);i--)
#define gson(i,root) for(int i=ptx[root];~i;i=ed[i].next)
#define tdata int testnum;scanff(testnum);for(int cas=1;cas<=testnum;cas++)
#define mem(x,val) memset(x,val,sizeof(x))
#define mkp(a,b) make_pair(a,b)
#define findx(x) lower_bound(b+1,b+1+bn,x)-b
#define pb(x) push_back(x)
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
#define NN 20020
int n,k;
int ptx[NN],lnum;
struct edge{
int v,next,w;
edge(){}
edge(int v,int next,int w){
this->v=v;
this->next=next;
this->w=w;
}
}ed[NN*2];
void addline (int x,int y,int w){
ed[lnum]=edge(y,ptx[x],w);
ptx[x]=lnum++;
}
int sz[NN],maxv[NN],maxval,dis[NN],dn;
int center;
bool vis[NN];
int getsize(int x,int fa){
sz[x]=1;
gson(i,x){
int y=ed[i].v;
if(y==fa||vis[y])continue;
sz[x]+=getsize(y,x);
}
return sz[x];
}
void getcenter(int r,int x,int fa){
maxv[x]=0;
gson(i,x){
int y=ed[i].v;
if(y==fa||vis[y])continue;
getcenter(r,y,x);
maxv[x]=max(maxv[x],sz[y]);
}
maxv[x]=max(maxv[x],sz[r]-sz[x]);
if(maxv[x]<maxval)maxval=maxv[x],center=x;
}
void getdis(int x,int fa,int d){
dis[++dn]=d;
gson(i,x){
int y=ed[i].v;
if(y==fa||vis[y])continue;
getdis(y,x,d+ed[i].w);
}
}
int calc(int x,int d){
dn=0;
getdis(x,0,d);
sort(dis+1,dis+1+dn);
int j=dn;
int sum=0;
rep(i,1,dn){
while(dis[i]+dis[j]>k&&j>=i)j--;
if(i>=j)break;
sum+=j-i;
}
return sum;
}
int ans;
int solve(int x){
maxval=inf;
getsize(x,0);
getcenter(x,x,0);
vis[center]=true;
ans+=calc(center,0);
gson(i,center){
int y=ed[i].v;
if(vis[y])continue;
ans-=calc(y,ed[i].w);
solve(y);
}
}
void init(){
ans=0;
lnum=0;
rep(i,0,n)ptx[i]=-1,vis[i]=false;
}
int main(){
while(scanf("%d%d",&n,&k)!=EOF){
if(n==0&&k==0)break;
init();
rep(i,1,n-1){
int x,y,w;
scanff(x);
scanff(y);
scanff(w);
addline(x,y,w);
addline(y,x,w);
}
solve(1);
printf("%d\n",ans);
}
return 0;
}