比较基础的题
【题解】
无根树转有根树再考虑,对于根结点:
树上任意两点的路径有两种情况,即 经过根结点和不经过根结点
1. 要求经过根结点的,求出所有点到根的距离d[x],将所有d[i]+d[j]<=k的点对(i,j)的数量加入答案中。
但实际上,若(i,j)是来自同一棵子树的,它们的路径不经过根,不该被计数,把这样的点对数减去即可
2. 要求不经过根节点的,递归根的子结点,执行情况1即可
这个根节点不能随便给,每次需要求树的重心
树的重心:这个结点包含子树的size值(结点个数)中最大的尽量小
可以证明重心的所有子树的size值均<=n/2
【代码】
细节比较多
有个地方不理解:
每次getroot()以后,因为原树中root的父结点会变为root的子结点,那么需不需要把原来树中root的父结点的size值更新为sum-size[root]?
我觉得需要,但下面的代码没有改,也AC
不过树的重心求偏并不会致错,实际上更不更新size[fa]耗时也的确差不多(hzwer也是这样说的)
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
int vis[10005]={0},size[10005]={0},maxsize[10005]={0},d[10005]={0},a[10005]={0};
int v[20005]={0},w[20005]={0},first[20005]={0},next[20005]={0};
int k,e,sum,root,p;
int max(int a,int b)
{
if(a>b) return a;
return b;
}
void tj(int x,int y,int z)
{
v[++e]=y;
w[e]=z;
next[e]=first[x];
first[x]=e;
}
void kp(int low,int high)
{
int i=low,j=high,mid=a[(i+j)/2],t;
while(i<j)
{
while(a[i]<mid) i++;
while(a[j]>mid) j--;
if(i<=j)
{
t=a[i];
a[i]=a[j];
a[j]=t;
i++;
j--;
}
}
if(j>low) kp(low,j);
if(i<high) kp(i,high);
}
void getroot(int x,int fa)
{
int i;
size[x]=1;
maxsize[x]=0;
for(i=first[x];i!=0;i=next[i])
if(vis[v[i]]==0&&v[i]!=fa)//"vis[v[i]]==0&&"不可删:从root出发可能走到之前"大树"的root,应避免
{
getroot(v[i],x);
size[x]+=size[v[i]];
maxsize[x]=max(maxsize[x],size[v[i]]);
}
maxsize[x]=max(maxsize[x],sum-size[x]);
if(root==0||maxsize[x]<maxsize[root]) root=x;
}
void getd(int x,int fa)
{
int i;
a[++p]=d[x];
for(i=first[x];i!=0;i=next[i])
if(vis[v[i]]==0&&v[i]!=fa)//vis[v[i]]==0:从root出发可能走到之前"大树"的root,应避免; v[i]!=fa:不能走向直接祖先
{
d[v[i]]=d[x]+w[i];
getd(v[i],x);
}
}
int getcnt(int x,int now)
{
int left=1,right,cnt=0;
p=0;
d[x]=now;
getd(x,0);
kp(1,p);
for(right=p;left<=right;)
{
if(a[left]+a[right]<=k)
{
cnt+=right-left;
left++;
}
else right--;
}
return cnt;
}
int work(int x)
{
int i,ans=0;
vis[x]=1;
ans=getcnt(x,0);
for(i=first[x];i!=0;i=next[i])
if(vis[v[i]]==0)//"x!=fa"不对:fa是x所在子树的祖先,并非x的直接祖先
{
ans-=getcnt(v[i],w[i]);
root=0;//注意getroot前的初始化
sum=size[v[i]];//? 若v[i]为实际树(非root为根的树)中root的父亲,sum应为size[v[i]]-size[root]?
getroot(v[i],x);
ans+=work(root);
}
return ans;
}
int main()
{
int n,i,x,y,z;
while(scanf("%d%d",&n,&k)&&n!=0)
{
memset(first,0,sizeof(first));
memset(vis,0,sizeof(vis));
e=0;
for(i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
tj(x,y,z);
tj(y,x,z);
}
root=0;
sum=n;
getroot(1,0);
printf("%d\n",work(root));
}
return 0;
}