题目描述:有n个点和n-1条边,dist(u,v)为u和v之间的最小距离,问两个点之间的dist(u,v)<=k的uv对数。
做法:假如枚举每一个点暴力跑肯定是超时的,那么这道题就需要lgn的级数来做,看了大牛们的做法是树上分治。大体思想为每回求出当前的树重心,以重心为根得到重心到子结点<=k的距离,大于k就没必要要了。对得到距离进行排序,然后可以两边从中间跑,当dis(l)+dis(r)<=k时方案数有r-l种,但是我们要的是重心两边的子树之间的方案数,而不是单独一边子树内的方案数。而当前的r-l方案中也包含了在同一颗子树内的两个结点的方案数,需要对这棵子树单独跑出距离进行同样的处理减去这颗子树的r-l种方案。最后结束后再次调用一整个函数对当前重心的的子树分别进行调用,一直递归下去最终得到答案。
#include<stdio.h>
#include<string.h>
#include<iostream>
#include<vector>
#include<algorithm>
using namespace std;
const int maxn=1e5+10,inf=1e9;
int k,ans,tot;
struct node
{
int v,w;
node(int _v=0,int _w=0)
{
v=_v;
w=_w;
}
};
vector<node>e[maxn];
struct node1
{
int sum,big;
}vec[maxn];
int sign[maxn],size[maxn],dist[maxn],vis[maxn];
int max(int a,int b)
{
return a>b? a:b;
}
void dfs(int u,int fa)
{
vec[u].sum=0;
vec[u].big=0;
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i].v;
if(fa!=v&&vis[v]==0)
{
dfs(v,u);
vec[u].sum+=vec[v].sum;
vec[u].big=max(vec[u].big,vec[v].sum);
}
}
vec[u].sum++;
sign[tot]=u;
size[tot++]=vec[u].big;
}
int Getroot(int u)
{
tot=0;
dfs(u,-1);
int maxx=inf,maxi=-1,cnt=vec[u].sum;
for(int i=0;i<tot;i++)
{
size[i]=max(size[i],cnt-size[i]);
if(size[i]<=maxx)
{
maxx=size[i];
maxi=sign[i];
}
}
return maxi;
}
void Getdist(int u,int fa,int dis)
{
dist[tot++]=dis;
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i].v;
if(fa!=v&&vis[v]==0&&dis+e[u][i].w<=k)
{
Getdist(v,u,dis+e[u][i].w);
}
}
}
void Count1(int u)
{
sort(dist,dist+tot);
int l=0,r=tot-1;
while(l<r)
{
if(dist[l]+dist[r]<=k)
{
ans+=r-l;
++l;
}
else --r;
}
}
void Count2(int u)
{
vis[u]=1;
for(int i=0;i<e[u].size();i++)
{
int v=e[u][i].v;
if(vis[v]==0)
{
tot=0; Getdist(v,u,e[u][i].w);
sort(dist,dist+tot);
int l=0,r=tot-1;
while(l<r)
{
if(dist[l]+dist[r]<=k)
{
ans-=r-l;
++l;
}
else --r;
}
}
}
}
void solve(int u,int fa)
{
int root=Getroot(u);
tot=0; Getdist(root,-1,0);
Count1(root);
Count2(root);
for(int i=0;i<e[root].size();i++)
{
int v=e[root][i].v;
if(v!=fa&&vis[v]==0)
{
solve(v,root);
}
}
}
int main()
{
int n;
while(~scanf("%d %d",&n,&k),n+k)
{
for(int i=1;i<=n;i++)
e[i].clear();
for(int i=1;i<n;i++)
{
int u,v,w;
scanf("%d %d %d",&u,&v,&w);
e[u].push_back(node(v,w));
e[v].push_back(node(u,w));
}
ans=0;
memset(vis,0,sizeof(vis));
solve(1,-1);
printf("%d\n",ans);
}
}