树分治论文:https://wenku.baidu.com/view/8861df38376baf1ffc4fada8.html?re=view
树分治讲解:https://blog.csdn.net/qq_31759205/article/details/75579558
题目链接:https://cn.vjudge.net/problem/POJ-1741
题意:给定一棵带边权的树,求树上距离不超过K(即<=K)的点对数
题解:通常就是用点分治,有复杂度的保证
1、找出树的重心,即删去该点后剩余各连通块中最大节点数最小的点
2、将树的重心作为根节点root,计算树中每个点到root的距离dis
3、计算树中所有满足dis[u]+dis[v]<=k的点对数cnt1
4、计算以root的子节点为根的子树中,满足dis[u]+dis[v]<=k的点对数cnt2
5、ans+=cnt1-cnt2
6、删掉节点root,分别遍历root的子树,回到第1步
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
// Capacity of every per-vertex array (vertices are indexed from 1).
const int N = 10000 + 10;

// One directed half of an undirected tree edge, kept in a linked
// adjacency list threaded through e[] via `nex`.
struct EDGE {
    int to;   // vertex this half-edge points to
    int d;    // edge weight
    int nex;  // index in e[] of the next half-edge of the same source, -1 at the end
};
EDGE e[2 * N];  // AddEdge is called once per direction, hence twice the capacity

int n;     // number of vertices in the current test case
int k;     // distance bound: pairs with distance <= k are counted
int pnum;  // vertex count getroot assumes for the component being split

int head[N];  // head[v] = index in e[] of v's first half-edge, -1 if none
int len;      // number of half-edges stored in e[] so far

int vis[N];  // vis[v] != 0 once v has been picked as a split root

int son[N];      // son[v] = subtree size of v from the latest getroot pass
int max_son[N];  // max_son[v] = largest piece left when v is removed
// Resets the adjacency lists and the "removed" marks for vertices 0..n and
// empties the edge pool, ready for a new test case.
void Init() {
    len = 0;
    for (int v = n; v >= 0; --v) {
        vis[v] = 0;
        head[v] = -1;  // -1 marks an empty adjacency list
    }
}
void AddEdge(int x, int y, int z) {
e[len].to = y;
e[len].d = z;
e[len].nex = head[x];
head[x] = len++;
}
// Walks the not-yet-removed component from u (coming from fa), filling son[]
// with subtree sizes and max_son[] with the largest piece left after deleting
// each vertex. `root`/`minn` track the vertex with the smallest max_son seen
// so far. The piece on the parent's side is taken as pnum - son[u], so pnum
// must hold the size of the component before the call.
void getroot(int u, int fa, int &root, int &minn) {
    son[u] = 1;
    max_son[u] = 0;
    for (int idx = head[u]; idx != -1; idx = e[idx].nex) {
        const int child = e[idx].to;
        const bool skip = (child == fa) || vis[child];
        if (!skip) {
            getroot(child, u, root, minn);
            son[u] += son[child];
            if (son[child] > max_son[u]) {
                max_son[u] = son[child];
            }
        }
    }

    // Everything that is not in u's subtree hangs off u's parent side.
    const int above = pnum - son[u];
    if (above > max_son[u]) {
        max_son[u] = above;
    }

    if (max_son[u] < minn) {
        root = u;
        minn = max_son[u];
    }
}
// Distances collected by the most recent getdis traversal.
vector<int> dis;

// Appends d for u, then recurses into every neighbour that is neither fa nor
// already removed, adding the connecting edge weight to d.
void getdis(int u, int fa, int d) {
    dis.push_back(d);
    for (int idx = head[u]; idx != -1; idx = e[idx].nex) {
        const int child = e[idx].to;
        if (child != fa && !vis[child]) {
            getdis(child, u, d + e[idx].d);
        }
    }
}
// Collects the distances of all reachable, not-yet-removed vertices starting
// at u with initial offset d, and returns how many unordered pairs of those
// distances sum to at most k.
int getnum(int u, int d) {
    dis.clear();
    getdis(u, -1, d);
    sort(dis.begin(), dis.end());

    // Two pointers over the sorted distances: for each lo, hi is the last
    // index whose distance still pairs with dis[lo]; every index in (lo, hi]
    // is then a valid partner.
    int pairs = 0;
    int hi = static_cast<int>(dis.size()) - 1;
    for (int lo = 0; lo < hi; ++lo) {
        while (lo < hi && dis[lo] + dis[hi] > k) {
            --hi;
        }
        pairs += hi - lo;
    }
    return pairs;
}
// Counts the pairs with distance <= k inside the not-yet-removed component
// containing u. On entry pnum must equal that component's vertex count.
//
// The component is split at the vertex chosen by getroot; pairs whose path
// passes through that vertex are counted with getnum(root, 0), pairs that
// actually stay inside one child component are subtracted again, and the
// child components are handled recursively.
int dfs(int u) {
    const int total = pnum;  // pnum is overwritten below, keep the entry value
    int minn = N, root = u;
    getroot(u, -1, root, minn);
    vis[root] = 1;
    const int rootSize = son[root];

    int res = getnum(root, 0);
    for (int i = head[root]; i != -1; i = e[i].nex) {
        const int to = e[i].to;
        if (vis[to]) continue;

        // Remove pairs that lie entirely inside the component under `to`:
        // their path does not go through root.
        res -= getnum(to, e[i].d);

        // son[] was computed with u as the traversal root. If `to` is root's
        // parent in that traversal, son[to] still includes root's own
        // subtree; the real size of that side is whatever is not below root.
        pnum = (son[to] > rootSize) ? total - rootSize : son[to];
        res += dfs(to);
    }
    return res;
}
int main() {
int x, y, z = 1;
while(~scanf("%d %d", &n, &k) && (n || k)) {
Init();
for(int i = 1; i < n; i++) {
scanf("%d %d %d", &x, &y, &z);
AddEdge(x, y, z);
AddEdge(y, x, z);
}
pnum = n;
printf("%d\n", dfs(1));
}
return 0;
}
当然也可以先逐棵处理子树(减去同一子树内部的点对),把距离保存下来,最后再一起统计
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
#include <map>
using namespace std;
// Capacity of every per-vertex array (vertices are indexed from 1).
const int N = 10000 + 10;

// One directed half of an undirected tree edge, kept in a linked
// adjacency list threaded through e[] via `nex`.
struct EDGE {
    int to;   // vertex this half-edge points to
    int d;    // edge weight
    int nex;  // index in e[] of the next half-edge of the same source, -1 at the end
};
EDGE e[2 * N];  // AddEdge is called once per direction, hence twice the capacity

int n;     // number of vertices in the current test case
int k;     // distance bound: pairs with distance <= k are counted
int pnum;  // vertex count getroot assumes for the component being split

int head[N];  // head[v] = index in e[] of v's first half-edge, -1 if none
int len;      // number of half-edges stored in e[] so far

int vis[N];  // vis[v] != 0 once v has been picked as a split root

int son[N];      // son[v] = subtree size of v from the latest getroot pass
int max_son[N];  // max_son[v] = largest piece left when v is removed

int pval[N];  // scratch distance buffer
int cnt;      // number of entries used in pval[]
// Prepares the global graph state for a new test case: clears the edge pool
// and, for vertices 0..n, empties the adjacency list and the removed mark.
void Init() {
    int v = 0;
    while (v <= n) {
        head[v] = -1;  // -1 marks an empty adjacency list
        vis[v] = 0;
        ++v;
    }
    len = 0;
}
void AddEdge(int x, int y, int z) {
e[len].to = y;
e[len].d = z;
e[len].nex = head[x];
head[x] = len++;
}
// Traverses the not-yet-removed component from u (entered from fa) and
// records, for every vertex, its subtree size in son[] and the size of the
// largest piece left after deleting it in max_son[]. The vertex with the
// smallest max_son seen so far is reported through root/minn. pnum must hold
// the component size, since the parent-side piece is taken as pnum - son[u].
void getroot(int u, int fa, int &root, int &minn) {
    son[u] = 1;
    max_son[u] = 0;

    int idx = head[u];
    while (idx != -1) {
        const int child = e[idx].to;
        idx = e[idx].nex;
        if (child == fa) continue;
        if (vis[child]) continue;

        getroot(child, u, root, minn);
        son[u] += son[child];
        max_son[u] = (son[child] > max_son[u]) ? son[child] : max_son[u];
    }

    // The remainder of the component hangs off u's parent side.
    const int above = pnum - son[u];
    max_son[u] = (above > max_son[u]) ? above : max_son[u];

    if (max_son[u] < minn) {
        root = u;
        minn = max_son[u];
    }
}
int dis[N];  // collected distances; entries [0, tot) are in use
int tot;     // number of entries currently stored in dis[]

// Appends d for u to dis[], then recurses into every neighbour that is
// neither fa nor already removed, adding the connecting edge weight to d.
void getdis(int u, int fa, int d) {
    dis[tot] = d;
    ++tot;
    for (int idx = head[u]; idx != -1; idx = e[idx].nex) {
        const int child = e[idx].to;
        if (child != fa && !vis[child]) {
            getdis(child, u, d + e[idx].d);
        }
    }
}
int res;
void dfs(int u) {
int minn = N, root;
getroot(u, -1, root, minn);
vis[root] = 1;
int to, l, r;
tot = 0;
dis[tot++] = 0;
for(int i = head[root]; i != -1; i = e[i].nex) {
to = e[i].to;
if(vis[to]) continue;
cnt = 0;
l = tot;
getdis(to, root, e[i].d);
r = tot - 1;
sort(dis + l, dis + r + 1);
while(l < r) {
while(l < r && dis[l] + dis[r] > k) r--;
res -= r - l;
l++;
}
}
l = 0, r = tot - 1;
sort(dis, dis + tot);
while(l < r) {
while(l < r && dis[l] + dis[r] > k) r--;
res += r - l;
l++;
}
for(int i = head[root]; i != -1; i = e[i].nex) {
to = e[i].to;
if(vis[to]) continue;
pnum = son[to];
dfs(to);
}
}
int main() {
int x, y, z = 1;
while(~scanf("%d %d", &n, &k) && (n || k)) {
Init();
for(int i = 1; i < n; i++) {
scanf("%d %d %d", &x, &y, &z);
AddEdge(x, y, z);
AddEdge(y, x, z);
}
pnum = n;
res = 0;
dfs(1);
printf("%d\n", res);
}
return 0;
}
然后给出一个会T(超时)的错误示范:这里我是想把之前处理过的子树的距离保存在dis里,然后把新子树作为一个整体和它们配对统计。但是想一下复杂度:dis是累计保存的节点,每处理一棵子树都要把整个dis重新排序一次,当根的度数很大时(例如菊花图),单层的最坏复杂度会退化到O(n^2 log n),这样就会T了
// Slow variant of the per-root counting loop: each child subtree's distances
// (pval[0..cnt)) are matched against the distances accumulated in dis[0..tot)
// from the subtrees handled before it.
for(int i = head[root]; i != -1; i = e[i].nex) {
to = e[i].to;
if(vis[to]) continue;
cnt = 0;
// NOTE(review): presumably this variant's getdis writes the subtree's
// distances into pval[] and advances cnt; that definition is not shown here.
getdis(to, root, e[i].d);
sort(pval, pval + cnt);
// The whole accumulated array is re-sorted once per child subtree; this
// repeated sort is the expensive step of this variant.
sort(dis, dis + tot);
int l = 0, r = tot - 1;
// pval is ascending, so r only ever moves left: for each pval[l], shrink r
// until dis[r] fits within k; dis[0..r] are then all valid partners.
while(l < cnt && r >= 0) {
while(r >= 0 && pval[l] + dis[r] > k) r--;
if(r >= 0) res += r + 1;
l++;
}
// Append this subtree's distances to the accumulated array.
for(int j = 0; j < cnt; j++)
dis[tot++] = pval[j];
}