Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
Source
LouTiancheng@POJ
问你树上有多少点对满足距离小于等于k。
树分治裸题。答案分为经过根节点的路径和不经过根节点的路径。每次我们只计算经过根节点的路径条数。
可以通过统计所有点到根节点的距离计算。存到d数组里,问题成了计算d[i] + d[j] <= k
的 < i,j >个数。
排序( O(nlogn) )后可以在O(n)的时间内算出,具体看代码(getans函数)。得到答案ans1。
容易想到,ans1有不合法的路径,那就是这条路径起点终点在同一个子树内。这时我们应该减去这些方案数。解决方法就是,每次计算某个儿子为根的子树时,计算完毕后减去这棵子树的答案即可。
分治nlogn,加上sort,总复杂度是O(nlogn)
我代码在poj上老是TLE,刷了快半页了…挖个坑,以后再说。
代码:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int SZ = 10010;
int head[SZ],nxt[SZ << 1],n,k,tot = 0,dist[SZ];
struct edge{
int t,d;
}l[SZ << 1];
void build(int f,int t,int d)
{
l[++ tot].t = t;
l[tot].d = d;
nxt[tot] = head[f];
head[f] = tot;
}
int ans = 0,maxn,s,t,root;
bool rt[SZ];
int find(int u,int fa)
{
int sz = 1;
int now = 0;
for(int i = head[u];i;i = nxt[i])
{
int v = l[i].t;
if(!rt[v] && v != fa)
{
int son = find(v,u);
sz += son;
now = max(now,son);
}
}
now = max(now,n - sz);
if(now < maxn) maxn = now,root = u;
return sz;
}
void dfsdist(int u,int fa,int d)
{
dist[++ t] = d;
for(int i = head[u];i;i = nxt[i])
{
int v = l[i].t;
if(!rt[v] && v != fa)
dfsdist(v,u,d + l[i].d);
}
}
int getans(int s,int t)
{
sort(dist + s,dist + t + 1);
int ans = 0;
int r = t;
for(int i = s;i <= t;i ++)
{
while(dist[i] + dist[r] > k && r > i) r --;
ans += r - i;
if(r == i) break;
}
return ans;
}
void dfs(int x,int fa)
{
maxn = n;
find(x,fa);
int u = root;
s = 1,t = 0;
rt[u] = 1;
for(int i = head[u];i;i = nxt[i])
{
int v = l[i].t;
if(!rt[v])
{
s = t + 1;
dfsdist(v,u,l[i].d);
ans -= getans(s,t);
}
}
dist[++ t] = 0;
ans += getans(1,t);
for(int i = head[u];i;i = nxt[i])
if(!rt[l[i].t]) dfs(l[i].t,u);
}
void init()
{
memset(head,0,sizeof(head));
memset(rt,0,sizeof(rt));
ans = tot = 0;
}
void scanf(int &n)
{
n = 0;
char a = getchar();
bool flag = 0;
while(a < '0' || a > '9') { if(a == '-') flag = 1; a = getchar(); }
while(a >= '0' && a <= '9') n = (n << 3) + (n << 1) + a - '0',a = getchar();
if(flag) n = -n;
}
int main()
{
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
while(233)
{
init();
scanf(n); scanf(k);
if(!n && !k) break;
for(int i = 1,a,b,c;i < n;i ++)
{
scanf(a); scanf(b); scanf(c);
build(a,b,c);
build(b,a,c);
}
dfs(1,0);
printf("%d\n",ans);
}
return 0;
}
———-以上是12.28号的事情———-
———-以下是12.29号的事情———-
填坑了…树分治导致我树的重心打错,导致昨天T了三个题……
树的重心需要用到当前树的总点数,我直接用的n,所以重心找错了…
所以说要在找重心之前先计算一下当前树的大小,然后就可以AC了。
代码:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<algorithm>
using namespace std;
const int SZ = 10010;
int head[SZ],nxt[SZ << 1],n,k,tot = 0,dist[SZ];
struct edge{
int t,d;
}l[SZ << 1];
void build(int f,int t,int d)
{
l[++ tot].t = t;
l[tot].d = d;
nxt[tot] = head[f];
head[f] = tot;
}
int ans = 0,maxn,s,t,root;
bool rt[SZ];
int find(int u,int fa,int n)
{
int sz = 1;
int now = 0;
for(int i = head[u];i;i = nxt[i])
{
int v = l[i].t;
if(!rt[v] && v != fa)
{
int son = find(v,u,n);
sz += son;
now = max(now,son);
}
}
now = max(now,n - sz);
if(now < maxn) maxn = now,root = u;
return sz;
}
void dfsdist(int u,int fa,int d)
{
dist[++ t] = d;
for(int i = head[u];i;i = nxt[i])
{
int v = l[i].t;
if(!rt[v] && v != fa)
dfsdist(v,u,d + l[i].d);
}
}
int getans(int s,int t)
{
sort(dist + s,dist + t + 1);
int ans = 0;
int r = t;
for(int i = s;i <= t;i ++)
{
while(dist[i] + dist[r] > k && r > i) r --;
ans += r - i;
if(r == i) break;
}
return ans;
}
int dfssz(int u,int fa)
{
int sz = 1;
for(int i = head[u];i;i = nxt[i])
{
int v = l[i].t;
if(!rt[v] && v != fa)
sz += dfssz(v,u);
}
return sz;
}
void dfs(int x,int fa)
{
int sz = dfssz(x,fa);
maxn = n;
find(x,fa,sz);
int u = root;
s = 1,t = 0;
rt[u] = 1;
for(int i = head[u];i;i = nxt[i])
{
int v = l[i].t;
if(!rt[v])
{
s = t + 1;
dfsdist(v,u,l[i].d);
ans -= getans(s,t);
}
}
dist[++ t] = 0;
ans += getans(1,t);
for(int i = head[u];i;i = nxt[i])
if(!rt[l[i].t]) dfs(l[i].t,u);
}
void init()
{
memset(head,0,sizeof(head));
memset(rt,0,sizeof(rt));
ans = tot = 0;
}
void scanf(int &n)
{
n = 0;
char a = getchar();
bool flag = 0;
while(a < '0' || a > '9') { if(a == '-') flag = 1; a = getchar(); }
while(a >= '0' && a <= '9') n = (n << 3) + (n << 1) + a - '0',a = getchar();
if(flag) n = -n;
}
int main()
{
// freopen("in.txt","r",stdin);
// freopen("out.txt","w",stdout);
while(233)
{
init();
scanf(n); scanf(k);
if(!n && !k) break;
for(int i = 1,a,b,c;i < n;i ++)
{
scanf(a); scanf(b); scanf(c);
build(a,b,c);
build(b,a,c);
}
dfs(1,0);
printf("%d\n",ans);
}
return 0;
}