Problem
Description
给出一棵带边权的树,问有多少对点的距离<=Len
Input
第一行两个整数N,Len(2<=n<=10000,len<=maxlongint)
接下来N-1行,每行3个整数,x,y,l,表示x和y有一条边长为l的边
Output
一行,一个整数ans,表示答案
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
Sample Output
8
Solution
总体方法:点分治
这道题要求树的重心。
树的重心定义:已该点为根,则它的所有子树的最大节点最小。
那么我们从树的中心出发,以它为根节点到达叶节点的深度一定是最小的。
设现在正在处理以x为根节点的子树,设dis[y(y为x子树中的一点)]为y点到x点的距离。
那么如何求点对呢?首先我们将dis从小到大排个序,定义l,r两个指针,一个指向头,一个指向尾。如下图。
如果dis[l]+dis[r]>len则r–,否则将点对数增加r-l对,并且l++。(why?因为如果dis[l]+dis[r]合法,那么dis[l]+dis[r-1],dis[l]+dis[r-2]到dis[l]+dis[l+1]都合法)
那么问题来了。现在的点对有两种情况:一种是经过点x的,一种是不经过点x的,不经过点x的就要统统删掉,因为不经过x点的点对一定会再x的儿子的子树中被计算过,重复了,所以要删掉。
然后就输出答案了。
Code
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
using namespace std;
int head[20010],next[20010],go[20010],val[20010],dis[10010],size[10010],mx[10010],order[10010];
bool bz[20010];
int i,j,k,len,n,tot,x,y,z,hv,od,ans;
void lb(int x,int y,int z)
{
go[++tot]=y;
next[tot]=head[x];
head[x]=tot;
val[tot]=z;
}
void zheavy(int set,int x,int y)
{
int i;
mx[x]=0;size[x]=1;
for (i=head[x];i;i=next[i])
{
if (go[i]!=y && (!bz[go[i]]))
{
zheavy(set,go[i],x);
size[x]+=size[go[i]];
mx[x]=max(mx[x],size[go[i]]);
}
}
mx[x]=max(mx[x],set-size[x]);
if (mx[x]<mx[hv]) hv=x;
return;
}
void makedis(int x,int y)
{
int i,now;
order[++od]=x;
for (i=head[x];i;i=next[i])
{
now=go[i];
if (now!=y && (!bz[now]))
{
dis[now]=dis[x]+val[i];
makedis(now,x);
}
}
}
void qsort(int l,int r)
{
int i=l,j=r,mid=order[(l+r)/2];
while (i<j)
{
while (dis[order[i]]<dis[mid]) i++;
while (dis[order[j]]>dis[mid]) j--;
if (i<=j)
{
swap(order[i],order[j]);
i++;j--;
}
}
if (l<j) qsort(l,j);
if (i<r) qsort(i,r);
}
int qdd()
{
qsort(1,od);
int l=1,r=od,sum=0;
while (l<r)
{
if (dis[order[l]]+dis[order[r]]>len) r--;
else
{
sum+=r-l;
l++;
}
}
return sum;
}
void dg(int x,int y)
{
int i,now;
bz[x]=1;dis[x]=0;od=0;
makedis(x,y);
ans+=qdd();
for (i=head[x];i;i=next[i])
{
now=go[i];
if (!bz[now])
{
dis[now]=val[i];
od=0;
makedis(now,x);
ans-=qdd();
hv=0;
zheavy(size[now],now,x);
dg(hv,x);
}
}
}
int main()
{
scanf("%d%d",&n,&len);
for (i=1;i<=n-1;i++)
{
scanf("%d%d%d",&x,&y,&z);
lb(x,y,z);
lb(y,x,z);
}
mx[0]=2147438647;
zheavy(n,1,0);
dg(hv,0);
printf("%d",ans);
}
Special
这道题涉及到多种知识,是一道好题。
——2016.6.16