一、前言
在学习树链剖分之前,我们需要了解树链剖分是用来解决一个什么样的问题的算法。在平时的写题中,我们有时候会遇到有关线段的区间修改,求和问题,在这类问题中,我们往往会使用线段树来处理。但是如果我们需要在一颗树上就行区间操作和查询呢?这种时候我们该怎么处理呢?
这个时候我们想一想,如果我们能够把一颗树变成一条链的话,那不就可以用线段树来就行操作了吗?没错,这就是树链剖分的的作用,将一棵树转化为一条链,然后建立线段树,具体的方法见后文介绍。
二、前置知识点
1.链式前向星建图
2.线段树区间操作与查询+lazy标记
3.树的重链
4.LCA(最近公共祖先节点)
三、预处理
首先,题目中将会给出一颗树,而我们将要把这棵树按照重链拆开成为一条链。具体拆分效果如图:
在图中的树中,相同颜色的点将被分到同一条链中。划分依据为根据重链划分,重链的概念限于篇幅不在本篇博客中提及。然后这一颗树将会被重组为这样一条链。
这样用什么好处呢?我们发现对于每一个节点,以该节点为根的子树的节点在链中都是一个连续的区间(这里后面要考)。这就让我们非常方便的去遍历子树。相同颜色链的第一个节点就是重链的头结点。
在成功构造出链表后,我们需要根据DFS序列来为节点重新编号,其目的是为了方便在之后用线段树来处理和维护。
第一次DFS
通过第一次DFS,我们需要求出每一个节点的四个值:
- 该节点的父亲节点f[i];
- 该点的深度dep[i];
- 以该点为根的子树的节点数量sz[i];
- 该节点的重儿子son[i];
具体代码如下(比较简单所以思路就不多介绍了):
void dfs1(int u, int f, int d)
{
dep[u] = d, fa[u] = f, sz[u] = 1;
for (int i = h[u]; ~i; i = ne[i]) //链式前向星遍历边
{
int j = e[i];
if (j == f)
continue;
dfs1(j, u, d + 1); //dfs递归遍历
sz[u] += sz[j];
if (sz[son[u]] < sz[j]) //子树节点数最多的子节点为重儿子
son[u] = j;
}
}
第二次DFS
通过第二次DFS,我们需要求出以下几个数据
- 每一个节点的新id,id[i]
- 每一个节点所在重链的链首节点 top[i]
- 新节点对应点的值 nw[i]
具体代码如下:
//化树为链,t是重链的顶点
void dfs2(int u, int t)
{
id[u] = ++idx, nw[idx] = a[u], top[u] = t;
if (!son[u]) //这里如果没有重儿子则代表已经是叶子节点了
return;
dfs2(son[u], t); //重儿子重链顶点不变
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[u] || j == son[u])
continue;
dfs2(j, j); //其余点的重链顶点就是自己
}
}
预处理完后的效果图如下,可以结合上文内容理解
三、线段树部分
线段树基础部分
在成功的化树为链之后,我们就可以先把线段树的模板敲出来了,记得要带上区间修改和lazy标记。本篇博客重点不在此就不过多的介绍了。
ps:数据结构题代码真的写的好累…
struct node
{
int l, r, sum, lazy;
} tr[N << 2];
/*线段树的部分*/
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.lazy)
{
left.sum += root.lazy * (left.r - left.l + 1);
left.lazy += root.lazy;
right.sum += root.lazy * (right.r - right.l + 1);
right.lazy += root.lazy;
root.lazy = 0;
}
}
void build(int u, int l, int r)
{
tr[u] = { l, r, nw[r], 0 };
if (l == r)
return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void update(int u, int l, int r, int k)
{
if (l <= tr[u].l && r >= tr[u].r)
{
tr[u].lazy += k;
tr[u].sum += k * (tr[u].r - tr[u].l + 1);
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid)
update(u << 1, l, r, k);
if (mid < r)
update(u << 1 | 1, l, r, k);
pushup(u);
}
int query(int u, int l, int r)
{
if (l <= tr[u].l && r >= tr[u].r)
return tr[u].sum;
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
int res = 0;
if (l <= mid)
res += query(u << 1, l, r);
if (mid < r)
res += query(u << 1 | 1, l, r);
return res;
}
树链剖分部分
1.树上路径修改与查询
在区间修改时,先把两个点往上跳到同一个点上。只要两个点不在一条重链上,就不断的往上跳,然后将重链中的所有结点用线段树来进行修改。查询时同理,具体见代码中的注释。
void update_path(int u, int v, int k) //将树上u到v的路径上的点全部加上k
{
while (top[u] != top[v]) //当两个点不在一条重链上时
{
// u是深度高的重链
if (dep[top[u]] < dep[top[v]]) //深度高的点往上跳
swap(u, v);
update(1, id[top[u]], id[u], k); //线段树更新这一条重链
u = fa[top[u]]; //跳上另外一条重链上
}
//跳到一条重链上后就可以直接更新这个区间了
if (dep[u] < dep[v])
swap(u, v);
update(1, id[v], id[u], k);
}
//查询操作与更新操作同理
int query_path(int u, int v)
{
int res = 0;
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])
swap(u, v);
res += query(1, id[top[u]], id[u]);
u = fa[top[u]];
}
if (dep[u] < dep[v])
swap(u, v);
res += query(1, id[v], id[u]);
return res;
}
举个栗子:就在上文例子中的树中节点6和节点8全部加上1
step1:建立两个节点的指针
step2:top[5] != top[6],dep[top[6]] < dep[top[5]],橙色指针往上跳,区间修改id[top[5]] ~ id[5]
step3:top[2] != top[6],dep[top[2]] < dep[top[6]],黄色指针往上跳,区间修改id[top[6]] ~ id[6]
step4:top[2] == top[1],此时两个指针在同一条重链上了,区间修改id[1] ~ id[2]
查询同理
2.树上子树修改与查询
子树的修改与查询比起路径上就要简单多了,因为上文提到过我们在化树为链后形成的链后有一个特性,即以该节点为根的子树的节点在链中都是一个连续的区间(这里可以结合上文的图片进行理解),那么以u为根节点的子树的所有结点区间就是 id[u]~ id[u] + sz[u] - 1
void update_tree(int u, int k)
{
update(1, id[u], id[u] + sz[u] - 1, k);
}
int query_tree(int u)
{
return query(1, id[u], id[u] + sz[u] - 1);
}
四、LCA求法
在经过两次DFS处理过后,我们就已经可以通过logn的时间复杂度来求出两个节点的最近公共祖先LCA,怎么求呢?其实在上文的区间修改的时候就已经讲完了。
当两个点不属于一条重链时,我们可以通过指针的不断的往上跳跃来让两个指针达到同一条重链。然后当我们发现,当两个指针跳到同一条重链上后,“上面”的点就是两个点的最近公共祖先。
例如,LCA(5 , 6) = 1。具体求法可以看看上方的区间修改的例子。
五、代码模板
int n, m;
int a[N];
int h[N], e[N], ne[N], cnt;
//树链剖分后编号,权值,众联顶点编号,父节点,深度,子树结点数量,重儿子编号
int id[N], nw[N], top[N], fa[N], dep[N], sz[N], son[N], idx;
struct node
{
int l, r, sum, lazy;
} tr[N << 2];
void add(int a, int b)
{
ne[cnt] = h[a];
h[a] = cnt;
e[cnt++] = b;
}
/*树链剖分预处理*/
//预处理找出重儿子,深度
void dfs1(int u, int f, int d)
{
dep[u] = d, fa[u] = f, sz[u] = 1;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == f)
continue;
dfs1(j, u, d + 1);
sz[u] += sz[j];
if (sz[son[u]] < sz[j])
son[u] = j;
}
}
//化树为链,t是重链的顶点
void dfs2(int u, int t)
{
id[u] = ++idx, nw[idx] = a[u], top[u] = t;
if (!son[u])
return;
dfs2(son[u], t);
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[u] || j == son[u])
continue;
dfs2(j, j);
}
}
/*线段树的部分*/
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.lazy)
{
left.sum += root.lazy * (left.r - left.l + 1);
left.lazy += root.lazy;
right.sum += root.lazy * (right.r - right.l + 1);
right.lazy += root.lazy;
root.lazy = 0;
}
}
void build(int u, int l, int r)
{
tr[u] = { l, r, nw[r], 0 };
if (l == r)
return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void update(int u, int l, int r, int k)
{
if (l <= tr[u].l && r >= tr[u].r)
{
tr[u].lazy += k;
tr[u].sum += k * (tr[u].r - tr[u].l + 1);
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid)
update(u << 1, l, r, k);
if (mid < r)
update(u << 1 | 1, l, r, k);
pushup(u);
}
int query(int u, int l, int r)
{
if (l <= tr[u].l && r >= tr[u].r)
return tr[u].sum;
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
int res = 0;
if (l <= mid)
res += query(u << 1, l, r);
if (mid < r)
res += query(u << 1 | 1, l, r);
return res;
}
/*树链剖分部分*/
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])
swap(x,y);
x=fa[top[x]];
}
return dep[x]<dep[y]?x:y;
}
void update_path(int u, int v, int k)
{
while (top[u] != top[v])
{
// u是深度高的重链
if (dep[top[u]] < dep[top[v]])
swap(u, v);
update(1, id[top[u]], id[u], k);
u = fa[top[u]];
}
if (dep[u] < dep[v])
swap(u, v);
update(1, id[v], id[u], k);
}
int query_path(int u, int v)
{
int res = 0;
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])
swap(u, v);
res += query(1, id[top[u]], id[u]);
u = fa[top[u]];
}
if (dep[u] < dep[v])
swap(u, v);
res += query(1, id[v], id[u]);
return res;
}
void update_tree(int u, int k)
{
update(1, id[u], id[u] + sz[u] - 1, k);
}
int query_tree(int u)
{
return query(1, id[u], id[u] + sz[u] - 1);
}
六、例题
给定一棵树,树中包含 n 个节点(编号 1∼n),其中第 i 个节点的权值为 ai。
初始时,1 号节点为树的根节点。
现在要对该树进行 m 次操作,操作分为以下 4 种类型:
1 u v k,修改路径上节点权值,将节点 u 和节点 v 之间路径上的所有节点(包括这两个节点)的权值增加 k。
2 u k,修改子树上节点权值,将以节点 u 为根的子树上的所有节点的权值增加 k。
3 u v,询问路径,询问节点 u 和节点 v 之间路径上的所有节点(包括这两个节点)的权值和。
4 u,询问子树,询问以节点 u 为根的子树上的所有节点的权值和。
输入格式
第一行包含一个整数 n,表示节点个数。
第二行包含 n 个整数,其中第 i 个整数表示 ai。
接下来 n−1 行,每行包含两个整数 x,y,表示节点 x 和节点 y 之间存在一条边。
再一行包含一个整数 m,表示操作次数。
接下来 m 行,每行包含一个操作,格式如题目所述。
输出格式
对于每个操作 3 和操作 4,输出一行一个整数表示答案。
数据范围
1≤n,m≤1e5,
0≤ai,k≤1e5,
1≤u,v,x,y≤n
输入样例:
5
1 3 7 4 5
1 3
1 4
1 5
2 3
5
1 3 4 3
3 5 4
1 3 5 10
2 3 5
4 1
输出样例:
16
69
AC代码
#include <bits/stdc++.h>
#define endl '\n'
#define el endl
#define pb push_back
#define int long long
#define INF 0x3f3f3f3f
#define ull unsigned long long
#define with << ' ' <<
#define print(x) cout << (x) << endl
#define all(x) (x).begin(), (x).end()
#define mem(a, b) memset(a, b, sizeof(a))
#define f(i, l, r) for (int i = (l); i <= (r); i++)
#define ff(i, l, r) for (int i = (l); i >= (r); i--)
#define pr(x, n) f(_, 1, n) cout << (x[_]) << " \n"[_ == n];
#define ck(x) cerr << #x << "=" << x << '\n';
using namespace std;
typedef pair<int, int> PII;
const int N = 1e6 + 7, mod = 1e9 + 7;
int n, m;
int a[N];
int h[N], e[N], ne[N], cnt;
//树链剖分后编号,权值,众联顶点编号,父节点,深度,子树结点数量,重儿子编号
int id[N], nw[N], top[N], fa[N], dep[N], sz[N], son[N], idx;
struct node
{
int l, r, sum, lazy;
} tr[N << 2];
void add(int a, int b)
{
ne[cnt] = h[a];
h[a] = cnt;
e[cnt++] = b;
}
/*树链剖分预处理*/
//预处理找出重儿子,深度
void dfs1(int u, int f, int d)
{
dep[u] = d, fa[u] = f, sz[u] = 1;
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == f)
continue;
dfs1(j, u, d + 1);
sz[u] += sz[j];
if (sz[son[u]] < sz[j])
son[u] = j;
}
}
//化树为链,t是重链的顶点
void dfs2(int u, int t)
{
id[u] = ++idx, nw[idx] = a[u], top[u] = t;
if (!son[u])
return;
dfs2(son[u], t);
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j == fa[u] || j == son[u])
continue;
dfs2(j, j);
}
}
/*线段树的部分*/
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.lazy)
{
left.sum += root.lazy * (left.r - left.l + 1);
left.lazy += root.lazy;
right.sum += root.lazy * (right.r - right.l + 1);
right.lazy += root.lazy;
root.lazy = 0;
}
}
void build(int u, int l, int r)
{
tr[u] = { l, r, nw[r], 0 };
if (l == r)
return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void update(int u, int l, int r, int k)
{
if (l <= tr[u].l && r >= tr[u].r)
{
tr[u].lazy += k;
tr[u].sum += k * (tr[u].r - tr[u].l + 1);
return;
}
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
if (l <= mid)
update(u << 1, l, r, k);
if (mid < r)
update(u << 1 | 1, l, r, k);
pushup(u);
}
int query(int u, int l, int r)
{
if (l <= tr[u].l && r >= tr[u].r)
return tr[u].sum;
pushdown(u);
int mid = (tr[u].l + tr[u].r) >> 1;
int res = 0;
if (l <= mid)
res += query(u << 1, l, r);
if (mid < r)
res += query(u << 1 | 1, l, r);
return res;
}
/*树链剖分部分*/
void update_path(int u, int v, int k)
{
while (top[u] != top[v])
{
// u是深度高的重链
if (dep[top[u]] < dep[top[v]])
swap(u, v);
update(1, id[top[u]], id[u], k);
u = fa[top[u]];
}
if (dep[u] < dep[v])
swap(u, v);
update(1, id[v], id[u], k);
}
int query_path(int u, int v)
{
int res = 0;
while (top[u] != top[v])
{
if (dep[top[u]] < dep[top[v]])
swap(u, v);
res += query(1, id[top[u]], id[u]);
u = fa[top[u]];
}
if (dep[u] < dep[v])
swap(u, v);
res += query(1, id[v], id[u]);
return res;
}
void update_tree(int u, int k)
{
update(1, id[u], id[u] + sz[u] - 1, k);
}
int query_tree(int u)
{
return query(1, id[u], id[u] + sz[u] - 1);
}
void solve()
{
mem(h,-1);
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i < n; i++)
{
int a, b;
cin >> a >> b;
add(a, b), add(b, a);
}
dfs1(1,-1,1);
dfs2(1,1);
build(1,1,n);
cin >> m;
while (m--)
{
int op, u, v, k;
cin >> op;
if (op == 1)
{
cin >> u >> v >> k;
update_path(u, v, k);
}
else if (op == 2)
{
cin >> u >> k;
update_tree(u, k);
}
else if (op == 3)
{
cin >> u >> v;
print(query_path(u, v));
}
else
{
cin >> u;
print(query_tree(u));
}
}
}
signed main()
{
std::ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
// clock_t start_time = clock();
int __ = 1;
// cin>>__;
// init();
while (__--)
solve();
// clock_t end_time = clock();
// cerr << "Running time is: " << ( double )(end_time - start_time) / CLOCKS_PER_SEC * 1000 << "ms" << endl;
return 0;
}
作者:Avalon·Demerzel
更多内容见专栏:图论与数据结构