题意:给一棵树,问两个结点之间的距离小于等于K的有多少对。
思路:男人八题之一。使用树分治解决。简单来说,对于一棵树的。那么答案分为三种情况,1,两个结点在该根结点的子树上,此情况递归解决;2,两个结点在根的两个子树上,对于这种情况,我们可以找出所有结点到该根节点的距离,这个过程是O(n),对距离排序(O(nlogn))以后,可以O(n)求出小于等于K的对数,但注意这种方法和情况1中有重复,要删掉;3,这两个结点,一个是根另一个在子树上,这种情况可以加1个距离是0的结点来处理。这样分治就可以做了,但是对于极端情况,整个树是一条链,递归的高度将达O(n),整体复杂度将达O(n^2logn)会超时。解决办法在于如何分,这里用到了重心这个概念,即删掉该点以后最大子树的结点数最小的点。这样每次从重心上递归解决就行了。至于怎么求重心呢,每次找结点数最多的结点递归解决就行了。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
using namespace std;
const int INF=0x7fffffff;
struct Edge
{
int to,cost;
Edge(int a=0,int b=0):to(a),cost(b){}
};
const int maxn=100005;
bool isPoint[maxn];
int n,K;
vector<Edge> g[maxn];
int subsize[maxn];
void init(int n)
{
for(int i=1; i<=n; ++i)
g[i].clear(),subsize[i]=0,isPoint[i]=0;
}
int getSubSize(int rt,int fa)
{
subsize[rt]=1;
for(int i=0;i<g[rt].size();++i)
{
int v=g[rt][i].to;
if(v==fa||isPoint[v]) continue;
subsize[rt]+=getSubSize(v,rt);
}
return subsize[rt];
}
pair<int,int> getPoint(int rt,int fa,int t)
{
pair<int,int> res(INF,-1);
int m=0,s=1;
for(int i=0;i<g[rt].size();++i)
{
int v=g[rt][i].to;
if(v==fa||isPoint[v]) continue;
res=min(res,getPoint(v,rt,t));
m=max(m,subsize[v]);
s+=subsize[v];
}
m=max(m,t-s);
res=min(res,make_pair(m,rt));
return res;
}
void getLength(int rt,int fa,int d,vector<int> &ds)
{
ds.push_back(d);
for(int i=0;i<g[rt].size();++i)
{
int v=g[rt][i].to;
if(v==fa||isPoint[v]) continue;
getLength(v,rt,d+g[rt][i].cost,ds);
}
}
int getCount(vector<int> &ds)
{
int ans=0;
sort(ds.begin(),ds.end());
int st=0,ed=ds.size()-1;
while(st<ed)
{
if(ds[st]+ds[ed]>K) --ed;
else
{
ans+=(ed-st);
++st;
}
}
return ans;
}
int ans;
void solve(int rt)
{
getSubSize(rt,-1);
int s=getPoint(rt,-1,subsize[rt]).second;
isPoint[s]=true;
for(int i=0;i<g[s].size();++i)
{
int v=g[s][i].to;
if(isPoint[v]) continue;
solve(v);
}
vector<int> ds;
ds.push_back(0);
for(int i=0;i<g[s].size();++i)
{
int v=g[s][i].to;
if(isPoint[v]) continue;
vector<int> tds;
getLength(v,s,g[s][i].cost,tds);
ans-=getCount(tds);
ds.insert(ds.end(),tds.begin(),tds.end());
}
ans+=getCount(ds);
isPoint[s]=false;
}
int main()
{
while(scanf("%d%d",&n,&K)!=EOF)
{
if(!n&&!K)
break;
init(n);
for(int i=1;i<n;++i)
{
int a,b,c;
scanf("%d%d%d",&a,&b,&c);
g[a].push_back(Edge(b,c));
g[b].push_back(Edge(a,c));
}
ans=0;
solve(1);
printf("%d\n",ans);
}
return 0;
}