点分治
1.算法分析
点分治适合处理大规模的树上路径信息问题。
所谓点分治,一般的思想如下:我们先随意选择一个节点作为根节点 r t rt rt,所有完全位于其子树中的路径可以分为两种,一种是经过当前根节点的路径,一种是不经过当前根节点的路径。对于经过当前根节点的路径,又可以分为两种,一种是以根节点为一个端点的路径,另一种是两个端点都不为根节点的路径。而后者又可以由两条属于前者链合并得到。所以,对于枚举的根节点 r t rt rt,我们先计算在其子树中且经过该节点的路径对答案的贡献,再递归其子树对不经过该节点的路径进行求解。
2.模板
// 统计有多少对点之间的距离小于等于k的
// 模板注意点:距离为m,这个修改了太多参数都要修改;计算距离的getdis()进行了优化:if(dis[u] <= m),有的时候不适用,需要修改或者删除
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1e4 + 5, MAXM = MAXN * 2, MAXK = 5e6 + 10;
int h[MAXN], e[MAXM], ne[MAXM], idx, w[MAXM];
//? rt记录重心,sum记录当前树大小,cnt是计数器
int n, rt, sum, cnt, m;
//? tmp记录算出的距离,siz记录子树大小,dis[i]为rt与i之间的距离
//? maxp用于找重心:maxp[1]以1为根的树内的最大子树大小,q用于记录所有询问
int tmp[MAXN], siz[MAXN], dis[MAXN], maxp[MAXN], q[MAXN], p[MAXN];
//?
// judge[i]记录在之前子树中距离i是否存在,ans记录第k个询问是否存在,vis记录被删除的结点:vis[1]=1表示1点被删掉了
bool vis[MAXN];
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
// TODO 找重心
void getrt(int u, int fa) {
siz[u] = 1, maxp[u] = 0; // maxp初始化为最小值
//遍历所有儿子,用maxp保存最大大小的儿子的大小
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa || vis[j]) continue; //被删掉的也不要算
getrt(j, u);
siz[u] += siz[j];
if (siz[j] > maxp[u]) maxp[u] = siz[j]; //更新maxp
}
maxp[u] = max(maxp[u], sum - siz[u]); //考虑u的祖先结点
if (maxp[u] < maxp[rt]) rt = u; //更新重心(最大子树大小最小)
}
// TODO 计算各结点与根结点之间的距离并全部记录在tmp里
void getdis(int u, int fa) {
if(dis[u] <= m) tmp[cnt++] = dis[u]; //如果大于k就没有必要了再对cnt进行更新了(q[i]<=m)
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa || vis[j]) continue;
dis[j] = dis[u] + w[i];
getdis(j, u);
}
}
int get(int a[], int n) {
sort(a, a + n);
int res = 0;
for (int i = n - 1, j = -1; i >= 0; i--) {
while (j + 1 < i && a[j + 1] + a[i] <= m) j++;
j = min(j, i - 1);
res += j + 1;
}
return res;
}
// TODO 处理经过根结点的路径
// 计算经过u点的路径对答案的贡献,注意judge数组要存放之前子树里存在的路径长度,排除折返路径的可能
int solve(int u) {
int res = 0;
int pt = 0;
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (vis[j]) continue;
cnt = 0; //注意置零计数器
dis[j] = w[i];
getdis(j, u); //把距离都处理出来
res -= get(tmp, cnt); // 容斥原理,先减掉不经过k的点对
for (int j = 0; j < cnt; j++) {
//把存在的单条路径长度标上true,供下个子树用
if (tmp[j] <= m) res++; // 加上和根节点形成的点对
p[pt++] = tmp[j];
}
}
res += get(p, pt); // 然后全部求点对
return res;
}
// TODO 分治
int divide(int u) {
int res = 0;
vis[u] = true; //删除根结点
res += solve(u); //计算经过根结点的路径
for (int i = h[u]; ~i; i = ne[i]) {
//分治剩余部分
int j = e[i];
if (vis[j]) continue;
maxp[rt = 0] = sum = siz[j]; //把重心置为0,并把maxp[0]置为最大值
getrt(j, 0);
getrt(rt, 0); //与主函数相同,第二次更新siz大小
res += divide(rt); // 递归计算每个子树对答案的贡献
}
return res;
}
int main() {
while(scanf("%d%d", &n, &m) != EOF) {
if (!n && !m) break;
for (int i = 1; i <= n; ++i) h[i] = -1, vis[i] = false;
idx = 0;
for (int i = 1; i < n; i++) {
// 建边
int u, v, w;
cin >> u >> v >> w;
u++, v++; // 这里的点从0开始,加一变为从1开始
add(u, v, w), add(v, u, w);
}
maxp[0] = sum = n; // maxp[0]置为最大值(一开始rt=0)
getrt(1, 0); //找重心
getrt(rt, 0); //! 此时siz数组存放的是以1为根时的各树大小,需要以找出的重心为根重算
cout << divide(rt) << endl; //找好重心就可以开始分治了
}
return 0;
}
3.典型例题
CF161D Distance in Tree
题意: 求树上距离为k的点的对数
题解: 维护一个exist数组,exist[i]表示距离为i的点出现的次数。每次搜索完一颗子树后,就枚举所有的距离,然后加上exist内出现的次数
代码:
// 统计有多少对点之间的距离小于等于k的
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 5e4 + 5, MAXM = MAXN * 2, MAXK = 5e2 + 10;
int h[MAXN], e[MAXM], ne[MAXM], idx, w[MAXM];
//? rt记录重心,sum记录当前树大小,cnt是计数器
int n, rt, sum, cnt, m;
//? tmp记录算出的距离,siz记录子树大小,dis[i]为rt与i之间的距离
//? maxp用于找重心:maxp[1]以1为根的树内的最大子树大小,q用于记录所有询问
int tmp[MAXN], siz[MAXN], dis[MAXN], maxp[MAXN], p[MAXN], exist[MAXK];
//?
// judge[i]记录在之前子树中距离i是否存在,ans记录第k个询问是否存在,vis记录被删除的结点:vis[1]=1表示1点被删掉了
bool vis[MAXN];
void add(int a, int b, int c) {
e[idx] = b, w[idx] = c, ne[idx] = h[a], h[a] = idx++;
}
// TODO 找重心
void getrt(int u, int fa) {
siz[u] = 1, maxp[u] = 0; // maxp初始化为最小值
//遍历所有儿子,用maxp保存最大大小的儿子的大小
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa || vis[j]) continue; //被删掉的也不要算
getrt(j, u);
siz[u] += siz[j];
if (siz[j] > maxp[u]) maxp[u] = siz[j]; //更新maxp
}
maxp[u] = max(maxp[u], sum - siz[u]); //考虑u的祖先结点
if (maxp[u] < maxp[rt]) rt = u; //更新重心(最大子树大小最小)
}
// TODO 计算各结点与根结点之间的距离并全部记录在tmp里
void getdis(int u, int fa) {
if(dis[u] <= m) tmp[cnt++] = dis[u]; //如果大于k就没有必要了再对cnt进行更新了(q[i]<=m)
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa || vis[j]) continue;
dis[j] = dis[u] + w[i];
getdis(j, u);
}
}
// TODO 处理经过根结点的路径
// 计算经过u点的路径对答案的贡献
int solve(int u) {
int res = 0;
int pt = 0;
exist[0] = 1; // 因为任何一个点都可以和根节点组对,因此根节点的距离要初始化
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (vis[j]) continue;
cnt = 0; //注意置零计数器
dis[j] = w[i];
getdis(j, u); //把距离都处理出来
for (int j = 0; j < cnt; j++) {
//把存在的单条路径长度标上true,供下个子树用
res += exist[m - tmp[j]];
p[pt++] = tmp[j];
}
for (int j = 0; j < cnt; ++j) exist[tmp[j]]++;
}
for (int i = 0; i < pt; ++i) exist[p[i]] = 0;
return res;
}
// TODO 分治
int divide(int u) {
int res =