题目链接: http://poj.org/problem?id=1741
题意:
给一棵树,树上每条边都有权值,问你树上两个点之间的距离不超过K的点对有多少个。
思路:
看2009年国家集训队 漆子超 的论文 <<分治算法在树的路径问题中的应用>>。
文中已经把大概的解题思路都将清楚了,但是作为第一次写树分治的我来说依旧是茫然一片。。。
于是只能依靠网上的大牛了。。
这题融汇了很多的思想,细细品味后感觉收货还是挺大的。
总而言之,大体思路就是首先找当前树的重心作为树的根节点,然后计算该树中每个节点到根节点的距离并且保存下来然后排序,排完序后就可以用左右逼近的方法计算出有多少点对到根的距离之和小于等于给定值,但是还要减去两个点在同一个子树中的情况,再加上子树中满足条件的点对,子树中的情况可以递归处理。
解释一下重心,一棵树的重心就是去掉该点之后所有子树节点数的最大值最小。具体实现在代码中讲得很详细了。
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
#include <map>
#include <set>
#include <cmath>
#include <algorithm>
#include <functional>
#include <cmath>
#include <bitset>
using namespace std;
const int inf = 1<<28;
const int maxn = 10010;
struct Edge{
int to,next,w;
}e[maxn<<1];
int n,k;
int head[maxn],cnt;
int sum[maxn],Max[maxn];
int tot;
bool vis[maxn];
vector<int> V1,V2;
void init(){
cnt=0;
memset(head,-1,sizeof(head));
memset(vis,false,sizeof(vis));
}
void add(int u,int v,int w){
e[cnt].to = v;
e[cnt].w = w;
e[cnt].next = head[u];
head[u] = cnt++;
}
// 遍历子树,同时保存最大子树和节点标号
void dfs1(int u,int fa){
sum[u] = 1;
Max[u] = 0;
tot++;
for(int i=head[u]; i!=-1; i=e[i].next){
int v = e[i].to;
if(v == fa || vis[v]) continue;
dfs1(v,u);
sum[u] += sum[v];
Max[u] = max(Max[u],sum[v]);
}
V1.push_back(Max[u]);
V2.push_back(u);
}
//找重心, Max保存以每个节点为根的最大子树的值
// V1保存最大值,V2保存节点标号
int getWP(int u){
V1.clear(), V2.clear();
tot = 0;
dfs1(u,-1);
int Min = inf, res, temp;
for(int i=0;i<V1.size();i++){
temp = max(V1[i],tot-V1[i]); // 注意!
if(temp < Min) {
Min = temp;
res = V2[i];
}
}
return res;
}
// 计算每个节点到根节点的距离保存到vector中排序
void getDis(int u,int fa,int d){
V1.push_back(d);
for(int i=head[u]; i!=-1; i=e[i].next){
int v = e[i].to;
if(v == fa || vis[v]) continue;
getDis(v,u,d+e[i].w);
}
}
// 计算以u为根,且u到它父节点的距离为d,u的子树中符合条件的节点对数
int calc(int u,int d){
int ret = 0;
V1.clear();
getDis(u,-1,d);
sort(V1.begin(),V1.end());
int l = 0, r = V1.size()-1;
while(l < r){
if(V1[r] + V1[l] <= k){
ret += (r-l);
l++;
}else r--;
}
return ret;
}
// 树的点分治
// 原理:每次以重心为根递归下去,可以证明最多递归O(logN)次
int dfs(int u){
int ret = 0;
int rt = getWP(u); //找重心减少递归次数
//cout <<"WP:" <<u << " " << rt << endl;
ret += calc(rt,0); // 首先计算子树中所有符合条件的点对
vis[rt] = true; // 标记一个点就相当于将树分成了几个子树
for(int i=head[rt]; i!=-1; i=e[i].next){
int v = e[i].to;
if(!vis[v]){
ret -= calc(v,e[i].w); //减去在同一个子树中的点对
ret += dfs(v); //加上同一个子树中也有成立的
}
}
return ret;
}
int main(){
int u,v,w;
while(cin >> n >> k){
if(n==0 && k==0) break;
init();
for(int i=1;i<n;i++){
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);
add(v,u,w);
}
cout << dfs(1) << endl;
}
return 0;
}