题目
题意
求树上两点最短距离不超过 K 的对数。
题解
通过树的分治的方法,从根节点自上而下开始求其不同子树上的两点通过该点的最短距离小于 K 的对数。
有两点需要注意:
1. 求某节点不同子树满足条件的两点不好直接求。可以通过先求该节点的所有子节点满足条件的个数,然后减去其子树内所有满足条件的点的个数。这个通过 dfs 很容易实现。
2. 在分治的过程中,在找某子树满足条件的点对数的时候,可以从树的重心开始找,这样可以避免算法时间复杂度退化。
代码
#include <algorithm>
#include <bitset>
#include <cstring>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <climits>
#include <iostream>
#include <list>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <string>
#include <vector>
using namespace std;
const int MAX = 10005;
struct node{
int first,second;
};
vector<struct node>G[MAX];
int vis[MAX];
int getSize(int u,int fa){
int len = G[u].size();
int sum = 1;
for(int i=0;i<len;++i){
int v = G[u][i].first;
if(v != fa && vis[v] == 0){
sum += getSize(v,u);
}
}
return sum;
}
int getRoot(int u, int fa,int &root, int &mx, int n) {
int len = G[u].size();
int sum = 1;
int mx1 = 0;
for(int i=0;i<len;++i){
int v = G[u][i].first;
if(v != fa && vis[v] == 0){
int temp = getRoot(v,u,root,mx,n);
mx1= max(mx1, temp);
sum += temp;
}
}
mx1 = max(mx1, (n - sum));
if(mx1 < mx){
root = u;
mx = mx1;
}
return sum;
}
int sum;
int N[MAX];
int num;
int n,k;
void getDeep(int u, int fa, int d){
N[num++] = d;
int len = G[u].size();
for(int i=0;i<len;++i){
int v = G[u][i].first;
if(vis[v] == 0 && v != fa)
getDeep(v,u, d + G[u][i].second);
}
}
int calu(int u, int d){
num = 0;
getDeep(u,0,d);
sort(N, N + num);
int i = 0,j = num - 1, tot = 0;
while(i < j){
while(i < j && N[i] + N[j] > k) j--;
tot += (j - i);
i++;
}
return tot;
}
void dfs(int u){
int root = 0;
int mx = INT_MAX;
int n = getSize(u,0);
getRoot(u,0,root,mx,n);
vis[root] = 1;
sum += calu(root,0);
int len = G[root].size();
for(int i = 0;i < len;++i){
int v = G[root][i].first;
if(vis[v] == 0){
sum -= calu(v,G[root][i].second);
dfs(v);
}
}
}
void init(int n) {
for (int i = 1; i <= n; ++i) {
G[i].clear();
}
sum = 0;
memset(vis,0,sizeof(vis));
}
int main(){
while(scanf("%d %d",&n,&k) && n != 0 && k != 0 ){
init(n);
int u,v,l;
for (int i = 0; i < n - 1; ++i) {
scanf("%d %d %d",&u,&v,&l);
struct node temp;
temp.first = v;
temp.second = l;
G[u].push_back(temp);
temp.first = u;
G[v].push_back(temp);
}
dfs(1);
printf("%d\n",sum);
}
return 0;
}