POJ P1741 Tree
题目
Tree
Time Limit: 1000MS Memory Limit: 30000K
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
题目大意
给出一个有n个顶点的树,每条边都有一个长度(正整数小于1001)。
定义dist(u,v)=节点u和v之间的最小距离。
给出一个整数k,对每对(u,v)顶点都是有效的,如果且仅当dist(u,v)不超过k。
编写一个程序,该程序将计算对给定树有效的对数对数。
输入
输入包含几个测试用例。每个测试用例的第一行包含两个整数n,k。(n < = 10000)下面的n - 1行每个包含三个整数u,v,l,这意味着节点u和长度l的v之间有一条边。
最后一个测试用例后面是两个0。
输出
对于每个测试用例输出一行,为dist(u,v)不超过k的u,v的对数。
题解
点分治
代码
#include<cstdio>
#include<iostream>
#include<cstdlib>
#define maxn 10005
#define INF 0x7fffffff
using namespace std;
int n,m,tot,rt,size,ans;
int lnk[maxn],d[maxn],vis[maxn],f[maxn],son[maxn],dp[maxn];
struct edge{
int v,y,nxt;
} e[maxn*2];
int readln()
{
int x=0;
char ch=getchar();
while (ch<'0'||ch>'9') ch=getchar();
while ('0'<=ch&&ch<='9') x=x*10+ch-48,ch=getchar();
return x;
}
void insert(int x,int y,int v)
{
tot++;
e[tot].nxt=lnk[x];
lnk[x]=tot;
e[tot].v=v;
e[tot].y=y;
}
void getrt(int x,int fa)
{
f[x]=0;son[x]=1;
for (int i=lnk[x];i;i=e[i].nxt)
{
if (e[i].y==fa||vis[e[i].y]) continue;
getrt(e[i].y,x);
son[x]+=son[e[i].y];
f[x]=max(f[x],son[e[i].y]);
}
f[x]=max(f[x],size-son[x]);
if (f[x]<f[rt]) rt=x;
}
void getdp(int x,int fa)
{
dp[0]++;
dp[dp[0]]=d[x];
for (int i=lnk[x];i;i=e[i].nxt)
{
if (e[i].y==fa||vis[e[i].y]) continue;
d[e[i].y]=d[x]+e[i].v;
getdp(e[i].y,x);
}
}
void qsort(int l,int r)
{
int i=l,j=r,mid=dp[rand()%(r-l+1)+l],t;
do {
while (dp[i]<mid) i++;
while (dp[j]>mid) j--;
if (i<=j) {
t=dp[i];dp[i]=dp[j];dp[j]=t;
i++;j--;
}
} while (i<=j);
if (i<r) qsort(i,r);
if (l<j) qsort(l,j);
}
int cal(int x,int v)
{
d[x]=v;dp[0]=0;
getdp(x,0);
qsort(1,dp[0]);
int l=1,r=dp[0],ans=0;
while (l<=r)
if (dp[l]+dp[r]<=m) {
ans+=r-l;
l++;
}
else r--;
return ans;
}
void solve(int x)
{
ans+=cal(x,0);
vis[x]=1;
for (int i=lnk[x];i;i=e[i].nxt)
{
if (vis[e[i].y]) continue;
ans-=cal(e[i].y,e[i].v);
rt=0;
size=son[e[i].y];
getrt(e[i].y,0);
solve(rt);
}
}
int main()
{
while (true)
{
ans=0;rt=0;tot=0;size=0;f[0]=INF;
for (int i=1;i<=maxn;i++)
{
vis[i]=0;
lnk[i]=0;
}
n=readln();m=readln();
if (n==0&&m==0) break;
for (int i=1;i<n;i++)
{
int x=readln(),y=readln(),v=readln();
insert(x,y,v);insert(y,x,v);
}
getrt(1,0);
solve(rt);
printf("%d\n",ans);
}
}