Tree
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v. Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k. Write a program that will count how many pairs which are valid for a given tree. Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros. Output
For each test case output the answer on a single line.
Sample Input 5 4 1 2 3 1 3 1 1 4 2 3 5 1 0 0 Sample Output 8 Source |
[Submit] [Go Back] [Status] [Discuss]
题解:树的点分治
树的分治算法
树的分治算法的两个常见形式:
基于点的分治:
首先选取一个点将无根树转为有根树,再递归处理每一颗以根结
点的儿子为根的子树。
基于边的分治:
在树中选取一条边,将原树分成两棵不相交的树,递归处理。
首先我们考虑如何选取点(边)。对于基于点的分治,我们选取一个点,要求将其删去后,结点最多的树的结点个数最小,这个点被称为“树的重心”。而基于边的分治,我们选取的边要满足所分离出来的两棵子树的结点个数尽量平均,这条边称为“中心边”。而对于这两个问题,都可以使用在树上的动态规划来解决,时间复杂度均为O(N),其中N为树的结点总数。对于树的分治算法来说,递归的深度往往决定着算法效率的高低.
定理 1:存在一个点使得分出的子树的结点个数均不大于N/2
定理 2:如果一棵树中每个点的度均不大于D,那么存在一条边使得分出的两棵子树的结点个数在[N/(D+1),N*D/(D+1)] (N>=2)
由定理 1 可得,在基于点的分治中每次我们都会将树的结点个数减少一半,因此递归深度最坏是O(logn)的,在树是一条链的时候达到上界。
典型应用:给定一棵n个结点的带权树,定义dis(u,v)为v u,两点间的最短路径长度,路径的长度定义为路径上所有边的权和。再给定一个K,如果对于不同的两个结点b a,,如果满足dis(b,a)<=k,则称 (b,a)为合法点对。
求合法点对个数。
一条路径要么经过当前节点,要么在该节点的一颗子树里。
记depth(i)表示点i到根结点的路径长度,belong(i)=x(X为根结点的某个儿子,且结点i在以X为根的子树内)。那么我们要统计的就是:
满足depth(i)+depth(j)<=k,belong(i)!=belong(j) 的合法点对(i,j)的个数
= 满足depth(i)+depth(j)<=k的(i,j)的个数-depth(i)+depth(j)<=k,belong(i)=belong(j) 的(i,j)的个数
而对于这两个部分,都是要求出满足AI+AJ<=K(i,j)的对数。
将 A 排序后利用单调性我们很容易得出一个O(N)的算法,所以我们可以用O(NlogN)的时间来解决这个问题。
综上,此题使用树的分治算法时间复杂度为O(Nlog^2N).
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 200000
using namespace std;
int m,n,tot,root,sum,f[N],d[N],ans;
int point[N],next[N],son[N],vis[N],deep[N],v[N],len[N];
void add(int x,int y,int k)
{
tot++; next[tot]=point[x]; point[x]=tot; v[tot]=y; len[tot]=k;
tot++; next[tot]=point[y]; point[y]=tot; v[tot]=x; len[tot]=k;
}
void getroot(int x,int fa)
{
son[x]=1; f[x]=0;
for (int i=point[x];i!=-1;i=next[i])
{
if (vis[v[i]]||v[i]==fa) continue;
getroot(v[i],x);
son[x]+=son[v[i]];
f[x]=max(f[x],son[v[i]]);
}
f[x]=max(f[x],sum-son[x]);
if (f[x]<f[root]) root=x;
}
void getdeep(int x,int fa)
{
deep[++deep[0]]=d[x];
for (int i=point[x];i!=-1;i=next[i])
{
if (vis[v[i]]||v[i]==fa) continue;
d[v[i]]=d[x]+len[i];
getdeep(v[i],x);
}
}
int cal(int x,int now)
{
d[x]=now; deep[0]=0;
getdeep(x,0);
sort(deep+1,deep+deep[0]+1);
int l=1,r=deep[0];
int t=0;
while (l<r)
{
if (deep[l]+deep[r]<=m) t+=r-l,l++;
else r--;
}
return t;
}
void work(int x)
{
ans+=cal(x,0);
vis[x]=1;
for (int i=point[x];i!=-1;i=next[i])
{
if (vis[v[i]]) continue;
ans-=cal(v[i],len[i]);
sum=son[v[i]];
root=0;
getroot(v[i],root);
work(root);
}
}
int main()
{
while(true)
{
scanf("%d%d",&n,&m);
if (!n&&!m) break;
tot=-1;
memset(point,-1,sizeof(point));
memset(next,-1,sizeof(next));
memset(vis,0,sizeof(vis));
for (int i=1;i<n;i++)
{
int x,y,k; scanf("%d%d%d",&x,&y,&k);
add(x,y,k);
}
sum=n; f[0]=1000000000; ans=0;
getroot(1,0);
work(root);
printf("%d\n",ans);
}
}