题面
题目描述
给一颗n个节点的树,每条边上有一个距离v(v<=1000)。
定义d(u,v)为u到v的最小距离;
给定k值,求有多少点对(u,v)使u到v的距离小于等于k。
输入格式
输入有多组数据,以两个0结尾
每组数据第一行两个整数n,k,n<=10000,k<2^31
接下来n-1行,每行三个整数x,y,v,表示点x到点y有一条边距离是v,v<=1000
输出格式
对每组数据一行一个答案。
input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
output
8
题解
这道题目是一个树分治——点分治的典型例题。
点分治流程:
1、寻找当前树的重心。
2、处理经过根结点的路径。
3、标记根结点为已访问。
4、递归处理以根节点的儿子为根的每棵子树。
#include<bits/stdc++.h>
using namespace std;
inline int read()
{
int num=0;
char c=' ';
bool flag=true;
for(;c>'9'||c<'0';c=getchar())
if(c=='-')
flag=false;
for(;c>='0'&&c<='9';num=(num<<3)+(num<<1)+c-48,c=getchar());
return flag ? num : -num;
}
const int maxn=10021;
int root;
int dcnt=0,ans=0;
//dcnt是存放了dis数组中数据的数量,ans存放了最终统计出来的数量
int now,vis[maxn],dis[maxn],d[maxn];
//now存放了当前子树的规模
//vis[i]代表i是否被访问
//d[i]存放了到根的距离
const int INF=2e9;
int f[maxn],size[maxn];
//这两个数组是用来求树的重心的
int n,k;
struct node
{
int dot,val,next;
}a[maxn*2];
int top=0,head[maxn];
//邻接表
void insert(int x,int y,int v)
{
top++;
a[top].dot=y;
a[top].next=head[x];
a[top].val=v;
head[x]=top;
}
//插入边
void dfs(int u,int fa)
{
size[u]=1;f[u]=0;
for(int i=head[u];i;i=a[i].next)
{
int v=a[i].dot;
if(v==fa||vis[v])continue;
dfs(v,u);
size[u]+=size[v];
f[u]=max(f[u],size[v]);
}
f[u]=max(f[u],now-size[u]);
if(f[u]<f[root])root=u;
}
//求树的重心放在变量root里
void get_dis(int u,int fa)
{
if(d[u]>k)return;
dis[++dcnt]=d[u];
for(int i=head[u];i;i=a[i].next)
{
int v=a[i].dot;
if(v==fa||vis[v])continue;
d[v]=d[u]+a[i].val;
get_dis(v,u);
}
}
int add(int u,int w)
//u的子树中,如果有两个点i,j
//满足d[i]+d[j]<k-w的数量
{
d[u]=w;dcnt=0;
get_dis(u,0);
sort(dis+1,dis+dcnt+1);
//排序一下,可以直接枚举两个端点,不用一个一个搜索过去
int l=1,r=dcnt,sum=0;
while(l<r)
{
if(dis[l]+dis[r]<=k)
{
sum+=r-l;
l++;
}
else r--;
}
return sum;
}
void solve(int u)
{
ans+=add(u,0);
vis[u]=1;
for(int i=head[u];i;i=a[i].next)
{
int v=a[i].dot;
if(vis[v])continue;
//如果访问过了,就说明该子树已经被解决
ans-=add(v,a[i].val);
now=size[v];root=0;
dfs(v,0);
//求子树的重心
solve(root);
//子树继续递归解决
}
}
void init()
{
now=n=read();
k=read();
while(n&&k)
{
memset(vis,0,sizeof vis);
top=ans=root=0;
f[0]=maxn;
memset(head,0,sizeof head);
for(int i=1;i<n;i++)
{
int x=read();
int y=read();
int v=read();
insert(x,y,v);
insert(y,x,v);
}
dfs(1,0);
solve(root);
printf("%d\n",ans);
now=n=read();
k=read();
}
}
int main()
{
init();
return 0;
}