题意:统计距离<=k的点对,n<=10000.
这道题用的树的点分治方法。点分治基于树的重心。
树的重心,定义是删除某个点后得到的最大(节点数)子树的节点数最小。性质是,可以证明删除掉重心后,每个子树的大小<=n/2。这个性质保证了基于重心的分治算法深度不会超过logn。
这题递归求解子树后,将所有子树节点d[i]距离排序,用两个指针扫过去就可以得到对于每一个d[i]满足的d[j]+d[i]<=k的j的个数。要注意的是减去重复的个数(这个比较麻烦)。
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long ll;
const int inf=0x3f3f3f3f;
const int maxn=1e4+10;
const int maxm=maxn<<1;
int n,k;
int tot,first[maxn],nxt[maxm], to[maxm], cost[maxm];
int vis[maxn],siz[maxn],maxs[maxn]={inf}, dep[maxn];
int a[maxn], in[maxn];
void addedge(int u, int v, int w)
{
nxt[tot]=first[u];
to[tot]=v;
cost[tot]=w;
first[u]=tot++;
}
void init()
{
memset(first ,-1, sizeof(first));
tot=0;
memset(vis, 0, sizeof(vis));
for(int i=1; i<n; i++){
int u,v,w;
scanf("%d%d%d", &u, &v, &w);
addedge(u,v,w);
addedge(v,u,w);
}
}
void getsize(int u, int pre) //记录节点数
{
siz[u]=1;
maxs[u]=0;
for(int i=first[u]; i!=-1; i=nxt[i]){
int v=to[i];
if(v==pre || vis[v]) continue;
getsize(v, u);
siz[u]+=siz[v];
maxs[u]=max(maxs[u], siz[v]);
}
}
void getroot(int u, int pre, int num, int &rt)//获取重心
{
maxs[u]=max(maxs[u], num-siz[u]);
if(maxs[u]<maxs[rt])
rt=u;
for(int i=first[u]; i!=-1; i=nxt[i]){
int v=to[i];
if(v==pre||vis[v]) continue;
getroot(v, u, num, rt);
}
}
void getdep(int u, int pre, int d, int &cnt)
{
a[cnt++]=d;
dep[u]=d;
for(int i=first[u]; i!=-1; i=nxt[i]){
int v=to[i], w=cost[i];
if(v==pre||vis[v]) continue;
getdep(v, u, d+w, cnt);
}
}
ll count(int l, int r) //对于排好序的数组,用两个指针扫描得到计数。
{
ll ret=0;
int ptr=r;
for(int i=l; i<=r && a[i]<=k && ptr>=l; i++){
while(a[ptr]+a[i]>k && ptr>=l) ptr--;
ret+=ptr-l+1;
}
return ret;
}
ll solve(int u)
{
ll ret=0;
int rt=0;
getsize(u, -1);
getroot(u, -1, siz[u], rt);
vis[rt]=1;
for(int i=first[rt]; i!=-1; i=nxt[i]){
int v=to[i], w=cost[i];
if(vis[v])continue;
ret+=solve(v);//递归求解子树
}
ll tmp=0;
int cnt=0;
for(int i=first[rt]; i!=-1; i=nxt[i]){
int v=to[i], w=cost[i];
if(vis[v]) continue;
int p=cnt;
getdep(v, rt, w, cnt);
sort(a+p, a+cnt);
tmp-=count(p, cnt-1);
}
a[cnt++]=0;
sort(a, a+cnt);
tmp+=count(0, cnt-1);
ret+=(tmp-1)/2;
vis[rt]=0;
return ret;
}
int main()
{
while(cin>>n>>k &&n+k){
init();
ll ans=solve(1);
cout<<ans<<endl;
}
return 0;
}