将整棵树剖分为若干条不相交的链,使它组合成线性结构,然后用其他的数据结构(树状数组或线段树等等)维护信息。
我们以它对树上两点之间路径上权值的最大值的查询和修改为例,来介绍树链剖分。首先有如下几个概念。
定义 重子节点 表示其子节点中子树最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。
定义 轻子节点 表示剩余的所有子结点。
从这个结点到重子节点的边为 重边 。
到其他轻子节点的边为 轻边 。
若干条首尾衔接的重边构成 重链 。
树剖的实现分两个 DFS 的过程。
第一个 DFS 记录每个结点的父节点(fa)、深度(d)、子树大小(size)、重子节点(son)。
代码如下:
void dfs1(int x, int depth) {
d[x] = depth;
size[x] = 1;
for (int i = head[x]; i; i = nex[i]) {
int y = to[i];
if (!d[y]) {
fa[y] = x;
dfs1(y, depth + 1);
size[x] += size[y];
if (!son[x] || size[y] > size[son[x]]) {
son[x] = y;
}
}
}
}
这个相对简单
第二个 dfs 记录所在链的链顶(top)、重边优先遍历时的 dfs 序(dfn)、dfs 序对应的节点编号(rk)。其实 r k rk rk和 d f n dfn dfn就是互为反函数,一个负责将树上的点映射到线性结构中,一个负责将线性结构中的点映射到树上。
代码如下:
void dfs2(int x, int tp) {
dfn[x] = ++num;
rk[num] = x;
top[x] = tp;
if (!son[x])return;
dfs2(x, tp);
for (int i = head[x]; i; i = nex[i]) {
int y = to[i];
if (y != son[x] && y != fa[x])dfs2(y, y);
}
}
求解过程中,我们求dfs序的时候,优先访问重儿子(可以保证重链上的时间戳连续)。求解重链的链头的时候,如果是在同一条重链上,那么链头不变,不在同一条重链上的话,如果它还有儿子,那么它就是它的这个儿子就是一个轻儿子(显而易见),并且还是如果这个儿子还有儿子,那么它也是一个重链的链头。所以后面如果不是重儿子且不是回边,就将后续节点所在的重链的链头设成了这个轻儿子。
然后就是用线段树维护 d f s dfs dfs序了,可以这样理解,通过求树的 d f s dfs dfs序,将树形结构线性化了,因此可以用线段树维护,每次对树上的区间操作的时候,分两种情况,两个点在一个重链上,这个简单,一条重链上的点的 d f s dfs dfs序一定都是连续的,直接在线段树中区间查询修改即可,第二种就是不在同一链上,我们就将链头较深的点上移,上移过程实际就是移动到当前点所在重链的链头,还是走了一个连续的 d f s dfs dfs序,我们可以在这个小区间查询 / 修改,然后移动到链头后再向它的父节点跳一下(因为刚刚并没有跳出这个重链),直到两个点在同一条重链上,就成了第一种情况了。
树链剖分还可以高效的求解LCA,学到了学到了
过程也不复杂,简单说说,首先两个点到达了LCA就是上升到x == y的时候了,此时两个点就在同一条重链上,因此还是先让两个点走到同一条重链上,然后这两个点深度小的就是LCA了。
半成的代码
// 树链剖分
#include <bits/stdc++.h>
using namespace std;
const int N = 310;
int head[N], to[N], nex[N], cnt;
int son[N];// 重儿子
int fa[N];// 父节点
int d[N];// 深度
int dfn[N];// dfs序
int size[N];// 子树大小
int rk[N];// dfs序对应节点编号,和dfn互为反函数
int top[N];// 所在链的链顶 初始都是自己,因为可能会有一点没有子节点且是轻儿子的情况
//此时它不在重链上,向上返回直接回到自己的父节点即可
int w[N];
int num;
void add(int a, int b) {
++cnt;
to[cnt] = b;
nex[cnt] = head[a];
head[a] = cnt;
}
void dfs1(int x, int depth) {
d[x] = depth;
size[x] = 1;
for (int i = head[x]; i; i = nex[i]) {
int y = to[i];
if (!d[y]) {
fa[y] = x;
dfs1(y, depth + 1);
size[x] += size[y];
if (!son[x] || size[y] > size[son[x]]) {
son[x] = y;
}
}
}
}
void dfs2(int x, int tp) {
dfn[x] = ++num;
rk[num] = x;
top[x] = tp;
if (!son[x])return;
dfs2(x, tp);
for (int i = head[x]; i; i = nex[i]) {
int y = to[i];
if (y != son[x] && y != fa[x])dfs2(y, y);
}
}
struct p {
int l, r, max, lazy;
};
struct SegementTree {
p c[N * 4];
void build(int l, int r, int k) {
c[k].l = l;
c[k].r = r;
if (l == r) {
c[k].max = w[rk[l]];
return;
}
int mid = (l + r) >> 1;
build(l, mid, k << 1);
build(mid + 1, r, k >> 1 | 1);
c[k].max = max(c[k << 1].max, c[k << 1 | 1].max);
}
int query(int l, int r, int k) {
if (l <= c[k].l && r >= c[k].r) {
return c[k].max;
}
if (l > c[k].r || c[k].l > r)return 0;
int ans = 0;
int mid = (c[k].l + c[k].r) >> 1;
if (l <= mid)ans += query(l, r, k << 1);
if (r > mid)ans += query(l, r, k << 1 | 1);
return ans;
}
int getAns(int x, int y) {
// 查询 x 到 y 路径上的最大值
int ans = 0;
while (top[x] != top[y]) {
if (d[top[x]] < d[top[y]]) {
swap(x, y);
}
ans += query(dfn[top[x]], dfn[x], 1);
x = top[x];
x = fa[x];
}
if (d[x] < d[y]) {
swap(x, y);
}
ans += query(dfn[y], dfn[x], 1);
return ans;
}
void down(int k) {
if (!c[k].lazy)return;
c[k << 1].max += c[k].lazy;
c[k << 1 | 1].max += c[k].lazy;
c[k << 1].lazy += c[k].lazy;
c[k << 1 | 1].lazy += c[k].lazy;
c[k].lazy = 0;
}
void modify(int l, int r, int k, int z) {
if (l > c[k].r || c[k].l > r)return;
if (l <= c[k].l && r >= c[k].r) {
c[k].max += z;
c[k].lazy += z;
}
down(k);
int mid = (c[k].l + c[k].r) >> 1;
if (l <= mid)modify(l, r, k << 1, z);
if (r > mid)modify(l, r, k << 1 | 1, z);
c[k].max = max(c[k << 1].max, c[k << 1 | 1].max);
}
void change(int x, int y, int z) {
while (top[x] != top[y]) {
if (d[top[x] < d[top[y]]]) {
swap(x, y);
}
modify(dfn[top[x]], dfn[x], 1, z);
x = fa[top[x]];
}
if (d[x] < d[y]) {
swap(x, y);
}
modify(dfn[y], dfn[x], 1, z);
}
};
SegementTree st;
int main()
{
// 数据输入
// 假设有个 n 吧 hhh
for (int i = 1; i <= n; i++){
top[i] = i;
}
int root;
dfs1(root, root);
dfs2(root, root);
st.build(1, num, 1);
int m;
cin >> m;// 询问
while (m--) {
// 以查询为例
int x, y;
cin >> x >> y;
cout << st.getAns(x, y) << "\n";
}
return 0;
}