本篇博客参考:
基本概念
首先,什么是 启发式合并 ?
有人将其称为“优雅的暴力”,启发式合并就是在合并两个部分的时候,将内容少的部分合并至内容多的部分,减少合并的操作时间
树上启发式合并(dsu on tree) 可以被用来解决树上的 离线问题(请注意,必须要是离线问题,因为处理问题的顺序有讲究),特别是可以维护以每个点为根的子树中的信息
一般来说,对于查询以每个点为根的子树中的信息的问题,我们可以用树形dp来处理,但是如果每个点的信息不止一两个数字,而是很庞大的部分(比如说每个点所需要的信息都要多个map来存储),这样使用树形dp的空间复杂度将会非常庞大,而树上启发式合并可以用来解决这样的问题
代码实现
举个例子,比如说我们给出一棵树,树上的每个结点染色,现在我们需要统计以每个结点为根的子树中出现多少种颜色
最暴力的方法就是每个结点跑一次 dfs,用 cnt[]
数组存储每个颜色出现的次数,输出
但是很明显会T的很惨
这样的树中,我们首先计算2子树的信息,然后计算3子树的信息的时候我们又要把2子树清空,每计算一个新的子树都要把之前计算过的信息清空,根本存不下来信息啊
然后我们考虑一下怎么优化呢,父结点的信息和子结点相关,我们可以用子结点的信息更新父结点的信息,也就是,我们在计算1子树的所有结点信息时,假如4子树是234里最后一个被计算的,那我们算完4子树之后,可以不用清空 cnt
数组,反正我们计算1子树的时候还是要遍历4子树的,将4子树的信息全部保留,再加上前面23子树的信息就可以得到1子树的信息了
这个4子树应该怎么选择呢?换句话说,我们保留哪一个子树的信息不被删除呢?根据启发式合并的思想,保留最庞大的子树信息不动,就可以减少重复计算的次数了
在树链剖分时我们把树中结点最多的子树根结点叫做重子结点,也就是说,在树上启发式合并的过程中,我们需要先计算所有轻子结点的信息(每计算一个轻子结点之后都要删除这个结点对当前答案的影响),最后计算重子结点的信息(保留重子结点对当前答案的影响),然后再计算前面的轻子结点(这一次计算要保留结点对当前答案的影响)
用两遍 dfs 实现
下面是一些变量定义:
sz[u]
以 u 为根的子树的结点数量son[u]
结点 u 的重子结点col[u]
结点 u 的颜色L[u]
结点 u 的 dfs 序R[u]
以 u 为根的子树中结点 dfs 序的最大值id[u]
L 标号 u 对应的结点编号,有id[L[u]] == u
cnt[u]
颜色 u 的出现次数totcol
目前出现过的颜色个数
void dfs1(int u, int fa) // u: 当前结点 fa: 父结点
{
L[u] = ++ totdfn; // 更新u的dfs序
Node[totdfn] = u; // 更新dfs序的映射
sz[u] = 1; // 初始化子树大小为1
for (int i = 0; i < g[u].size(); i ++ )
{
int j = g[u][i]; // 子结点编号
if (j == fa) continue;
dfs1(j, u);
sz[u] += sz[j]; // 用子结点的sz更新父结点的sz
if (sz[j] > sz[son[u]]) son[u] = j; // 更新重子结点
}
R[u] = totdfn; // 更新当前子树中dfs序的最大值
}
void dfs2(int u, int fa, bool keep) // u: 当前结点 fa: 父结点 keep: 此次遍历计算的答案是否保留
{
// 计算轻子结点的答案
for (int i = 0; i < g[u].size(); i ++ )
{
int j = g[u][i]; // 子结点编号
if (j == fa || j == son[u]) continue; // 遇到重子结点或者父结点就跳过
dfs2(j, u, false); // 继续计算轻子结点的答案且不保留
}
if (son[u]) dfs2(son[u], u, true); // 计算重儿子答案并保留计算过程中的数据(用于继承)
for (int i = 0; i < g[u].size(); i ++ )
{
int j = g[u][i]; // 子结点编号
if (j == fa || j == son[u]) continue; // 遇到重子结点或者父结点就跳过
// 子树结点的 DFS 序构成一段连续区间,可以直接遍历
for (int k = L[j]; k <= R[j]; k ++ ) add(id[k]); // 加上轻子结点对答案的贡献
}
add(u); // 加上当前子树根结点对答案的贡献
ans[u] = totcol;
if (keep == false) // 如果当前计算的答案不保留 就删去
for (int i = L[u]; i <= R[u]; i ++ ) del(id[i]);
}