【数据结构】树链剖分 (图文代码详解)

一、前言

在学习树链剖分之前,我们需要了解树链剖分是用来解决一个什么样的问题的算法。在平时的写题中,我们有时候会遇到有关线段的区间修改,求和问题,在这类问题中,我们往往会使用线段树来处理。但是如果我们需要在一颗树上就行区间操作和查询呢?这种时候我们该怎么处理呢?

这个时候我们想一想,如果我们能够把一颗树变成一条链的话,那不就可以用线段树来就行操作了吗?没错,这就是树链剖分的的作用,将一棵树转化为一条链,然后建立线段树,具体的方法见后文介绍。

二、前置知识点

1.链式前向星建图
2.线段树区间操作与查询+lazy标记
3.树的重链
4.LCA(最近公共祖先节点)

三、预处理

首先,题目中将会给出一颗树,而我们将要把这棵树按照重链拆开成为一条链。具体拆分效果如图:
在这里插入图片描述
在图中的树中,相同颜色的点将被分到同一条链中。划分依据为根据重链划分,重链的概念限于篇幅不在本篇博客中提及。然后这一颗树将会被重组为这样一条链。
在这里插入图片描述
这样用什么好处呢?我们发现对于每一个节点,以该节点为根的子树的节点在链中都是一个连续的区间(这里后面要考)。这就让我们非常方便的去遍历子树。相同颜色链的第一个节点就是重链的头结点。

在成功构造出链表后,我们需要根据DFS序列来为节点重新编号,其目的是为了方便在之后用线段树来处理和维护。

第一次DFS

通过第一次DFS,我们需要求出每一个节点的四个值:

  • 该节点的父亲节点f[i]
  • 该点的深度dep[i]
  • 以该点为根的子树的节点数量sz[i]
  • 该节点的重儿子son[i]

具体代码如下(比较简单所以思路就不多介绍了):

void dfs1(int u, int f, int d)
{
  dep[u] = d, fa[u] = f, sz[u] = 1;
  for (int i = h[u]; ~i; i = ne[i])	//链式前向星遍历边
  {
    int j = e[i];
    if (j == f)
      continue;
    dfs1(j, u, d + 1);  //dfs递归遍历
    sz[u] += sz[j];
    if (sz[son[u]] < sz[j])  //子树节点数最多的子节点为重儿子
      son[u] = j;
  }
}
第二次DFS

通过第二次DFS,我们需要求出以下几个数据

  • 每一个节点的新id,id[i]
  • 每一个节点所在重链的链首节点 top[i]
  • 新节点对应点的值 nw[i]

具体代码如下:

//化树为链,t是重链的顶点
void dfs2(int u, int t)
{
  id[u] = ++idx, nw[idx] = a[u], top[u] = t;
  if (!son[u])    //这里如果没有重儿子则代表已经是叶子节点了
    return;
  dfs2(son[u], t);   //重儿子重链顶点不变
  for (int i = h[u]; ~i; i = ne[i])
  {
    int j = e[i];
    if (j == fa[u] || j == son[u])
      continue;
    dfs2(j, j);    //其余点的重链顶点就是自己
  }
}

预处理完后的效果图如下,可以结合上文内容理解
在这里插入图片描述

三、线段树部分

线段树基础部分

在成功的化树为链之后,我们就可以先把线段树的模板敲出来了,记得要带上区间修改和lazy标记。本篇博客重点不在此就不过多的介绍了。

ps:数据结构题代码真的写的好累…

struct node
{
  int l, r, sum, lazy;
} tr[N << 2];

/*线段树的部分*/
void pushup(int u)
{
  tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void pushdown(int u)
{
  auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
  if (root.lazy)
  {
    left.sum += root.lazy * (left.r - left.l + 1);
    left.lazy += root.lazy;
    right.sum += root.lazy * (right.r - right.l + 1);
    right.lazy += root.lazy;
    root.lazy = 0;
  }
}

void build(int u, int l, int r)
{
  tr[u] = { l, r, nw[r], 0 };
  if (l == r)
    return;
  int mid = (l + r) >> 1;
  build(u << 1, l, mid);
  build(u << 1 | 1, mid + 1, r);
  pushup(u);
}

void update(int u, int l, int r, int k)
{
  if (l <= tr[u].l && r >= tr[u].r)
  {
    tr[u].lazy += k;
    tr[u].sum += k * (tr[u].r - tr[u].l + 1);
    return;
  }
  pushdown(u);
  int mid = (tr[u].l + tr[u].r) >> 1;
  if (l <= mid)
    update(u << 1, l, r, k);
  if (mid < r)
    update(u << 1 | 1, l, r, k);
  pushup(u);
}

int query(int u, int l, int r)
{
  if (l <= tr[u].l && r >= tr[u].r)
    return tr[u].sum;
  pushdown(u);
  int mid = (tr[u].l + tr[u].r) >> 1;
  int res = 0;
  if (l <= mid)
    res += query(u << 1, l, r);
  if (mid < r)
    res += query(u << 1 | 1, l, r);
  return res;
}
树链剖分部分
1.树上路径修改与查询

在区间修改时,先把两个点往上跳到同一个点上。只要两个点不在一条重链上,就不断的往上跳,然后将重链中的所有结点用线段树来进行修改。查询时同理,具体见代码中的注释。

void update_path(int u, int v, int k)    //将树上u到v的路径上的点全部加上k
{
  while (top[u] != top[v])    //当两个点不在一条重链上时
  {
    // u是深度高的重链
    if (dep[top[u]] < dep[top[v]])  //深度高的点往上跳
      swap(u, v);
    update(1, id[top[u]], id[u], k);   //线段树更新这一条重链
    u = fa[top[u]];   //跳上另外一条重链上
  }
  //跳到一条重链上后就可以直接更新这个区间了
  if (dep[u] < dep[v])
    swap(u, v);
  update(1, id[v], id[u], k);
}

//查询操作与更新操作同理
int query_path(int u, int v)
{
  int res = 0;
  while (top[u] != top[v])
  {
    if (dep[top[u]] < dep[top[v]])
      swap(u, v);
    res += query(1, id[top[u]], id[u]);
    u = fa[top[u]];
  }
  if (dep[u] < dep[v])
    swap(u, v);
  res += query(1, id[v], id[u]);
  return res;
}

举个栗子:就在上文例子中的树中节点6和节点8全部加上1
step1:建立两个节点的指针
在这里插入图片描述
step2:top[5] != top[6],dep[top[6]] < dep[top[5]],橙色指针往上跳,区间修改id[top[5]] ~ id[5]
在这里插入图片描述
step3:top[2] != top[6],dep[top[2]] < dep[top[6]],黄色指针往上跳,区间修改id[top[6]] ~ id[6]
在这里插入图片描述
step4:top[2] == top[1],此时两个指针在同一条重链上了,区间修改id[1] ~ id[2]
在这里插入图片描述
查询同理

2.树上子树修改与查询

子树的修改与查询比起路径上就要简单多了,因为上文提到过我们在化树为链后形成的链后有一个特性,即以该节点为根的子树的节点在链中都是一个连续的区间(这里可以结合上文的图片进行理解),那么以u为根节点的子树的所有结点区间就是 id[u]~ id[u] + sz[u] - 1

void update_tree(int u, int k)
{
  update(1, id[u], id[u] + sz[u] - 1, k);
}

int query_tree(int u)
{
  return query(1, id[u], id[u] + sz[u] - 1);
}

四、LCA求法

在经过两次DFS处理过后,我们就已经可以通过logn的时间复杂度来求出两个节点的最近公共祖先LCA,怎么求呢?其实在上文的区间修改的时候就已经讲完了。

当两个点不属于一条重链时,我们可以通过指针的不断的往上跳跃来让两个指针达到同一条重链。然后当我们发现,当两个指针跳到同一条重链上后,“上面”的点就是两个点的最近公共祖先。

例如,LCA(5 , 6) = 1。具体求法可以看看上方的区间修改的例子。

五、代码模板

int n, m;
int a[N];
int h[N], e[N], ne[N], cnt;
//树链剖分后编号,权值,众联顶点编号,父节点,深度,子树结点数量,重儿子编号
int id[N], nw[N], top[N], fa[N], dep[N], sz[N], son[N], idx;

struct node
{
  int l, r, sum, lazy;
} tr[N << 2];

void add(int a, int b)
{
  ne[cnt] = h[a];
  h[a] = cnt;
  e[cnt++] = b;
}

/*树链剖分预处理*/

//预处理找出重儿子,深度
void dfs1(int u, int f, int d)
{
  dep[u] = d, fa[u] = f, sz[u] = 1;
  for (int i = h[u]; ~i; i = ne[i])
  {
    int j = e[i];
    if (j == f)
      continue;
    dfs1(j, u, d + 1);
    sz[u] += sz[j];
    if (sz[son[u]] < sz[j])
      son[u] = j;
  }
}

//化树为链,t是重链的顶点
void dfs2(int u, int t)
{
  id[u] = ++idx, nw[idx] = a[u], top[u] = t;
  if (!son[u])
    return;
  dfs2(son[u], t);
  for (int i = h[u]; ~i; i = ne[i])
  {
    int j = e[i];
    if (j == fa[u] || j == son[u])
      continue;
    dfs2(j, j);
  }
}

/*线段树的部分*/
void pushup(int u)
{
  tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void pushdown(int u)
{
  auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
  if (root.lazy)
  {
    left.sum += root.lazy * (left.r - left.l + 1);
    left.lazy += root.lazy;
    right.sum += root.lazy * (right.r - right.l + 1);
    right.lazy += root.lazy;
    root.lazy = 0;
  }
}

void build(int u, int l, int r)
{
  tr[u] = { l, r, nw[r], 0 };
  if (l == r)
    return;
  int mid = (l + r) >> 1;
  build(u << 1, l, mid);
  build(u << 1 | 1, mid + 1, r);
  pushup(u);
}

void update(int u, int l, int r, int k)
{
  if (l <= tr[u].l && r >= tr[u].r)
  {
    tr[u].lazy += k;
    tr[u].sum += k * (tr[u].r - tr[u].l + 1);
    return;
  }
  pushdown(u);
  int mid = (tr[u].l + tr[u].r) >> 1;
  if (l <= mid)
    update(u << 1, l, r, k);
  if (mid < r)
    update(u << 1 | 1, l, r, k);
  pushup(u);
}

int query(int u, int l, int r)
{
  if (l <= tr[u].l && r >= tr[u].r)
    return tr[u].sum;
  pushdown(u);
  int mid = (tr[u].l + tr[u].r) >> 1;
  int res = 0;
  if (l <= mid)
    res += query(u << 1, l, r);
  if (mid < r)
    res += query(u << 1 | 1, l, r);
  return res;
}

/*树链剖分部分*/

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]])
            swap(x,y);
        x=fa[top[x]];
    }
    return dep[x]<dep[y]?x:y;
}


void update_path(int u, int v, int k)
{
  while (top[u] != top[v])
  {
    // u是深度高的重链
    if (dep[top[u]] < dep[top[v]])
      swap(u, v);
    update(1, id[top[u]], id[u], k);
    u = fa[top[u]];
  }
  if (dep[u] < dep[v])
    swap(u, v);
  update(1, id[v], id[u], k);
}

int query_path(int u, int v)
{
  int res = 0;
  while (top[u] != top[v])
  {
    if (dep[top[u]] < dep[top[v]])
      swap(u, v);
    res += query(1, id[top[u]], id[u]);
    u = fa[top[u]];
  }
  if (dep[u] < dep[v])
    swap(u, v);
  res += query(1, id[v], id[u]);
  return res;
}

void update_tree(int u, int k)
{
  update(1, id[u], id[u] + sz[u] - 1, k);
}

int query_tree(int u)
{
  return query(1, id[u], id[u] + sz[u] - 1);
}

六、例题

Acwing 树链剖分

给定一棵树,树中包含 n 个节点(编号 1∼n),其中第 i 个节点的权值为 ai。

初始时,1 号节点为树的根节点。

现在要对该树进行 m 次操作,操作分为以下 4 种类型:

1 u v k,修改路径上节点权值,将节点 u 和节点 v 之间路径上的所有节点(包括这两个节点)的权值增加 k。
2 u k,修改子树上节点权值,将以节点 u 为根的子树上的所有节点的权值增加 k。
3 u v,询问路径,询问节点 u 和节点 v 之间路径上的所有节点(包括这两个节点)的权值和。
4 u,询问子树,询问以节点 u 为根的子树上的所有节点的权值和。
输入格式
第一行包含一个整数 n,表示节点个数。

第二行包含 n 个整数,其中第 i 个整数表示 ai。

接下来 n−1 行,每行包含两个整数 x,y,表示节点 x 和节点 y 之间存在一条边。

再一行包含一个整数 m,表示操作次数。

接下来 m 行,每行包含一个操作,格式如题目所述。

输出格式
对于每个操作 3 和操作 4,输出一行一个整数表示答案。

数据范围
1≤n,m≤1e5,
0≤ai,k≤1e5,
1≤u,v,x,y≤n
输入样例:
5
1 3 7 4 5
1 3
1 4
1 5
2 3
5
1 3 4 3
3 5 4
1 3 5 10
2 3 5
4 1
输出样例:
16
69

AC代码

#include <bits/stdc++.h>
#define endl '\n'
#define el endl
#define pb push_back
#define int long long
#define INF 0x3f3f3f3f
#define ull unsigned long long
#define with << ' ' <<
#define print(x) cout << (x) << endl
#define all(x) (x).begin(), (x).end()
#define mem(a, b) memset(a, b, sizeof(a))
#define f(i, l, r) for (int i = (l); i <= (r); i++)
#define ff(i, l, r) for (int i = (l); i >= (r); i--)
#define pr(x, n) f(_, 1, n) cout << (x[_]) << " \n"[_ == n];
#define ck(x) cerr << #x << "=" << x << '\n';
using namespace std;
typedef pair<int, int> PII;
const int N = 1e6 + 7, mod = 1e9 + 7;

int n, m;
int a[N];
int h[N], e[N], ne[N], cnt;

//树链剖分后编号,权值,众联顶点编号,父节点,深度,子树结点数量,重儿子编号
int id[N], nw[N], top[N], fa[N], dep[N], sz[N], son[N], idx;

struct node
{
  int l, r, sum, lazy;
} tr[N << 2];

void add(int a, int b)
{
  ne[cnt] = h[a];
  h[a] = cnt;
  e[cnt++] = b;
}

/*树链剖分预处理*/

//预处理找出重儿子,深度
void dfs1(int u, int f, int d)
{
  dep[u] = d, fa[u] = f, sz[u] = 1;
  for (int i = h[u]; ~i; i = ne[i])
  {
    int j = e[i];
    if (j == f)
      continue;
    dfs1(j, u, d + 1);
    sz[u] += sz[j];
    if (sz[son[u]] < sz[j])
      son[u] = j;
  }
}

//化树为链,t是重链的顶点
void dfs2(int u, int t)
{
  id[u] = ++idx, nw[idx] = a[u], top[u] = t;
  if (!son[u])
    return;
  dfs2(son[u], t);
  for (int i = h[u]; ~i; i = ne[i])
  {
    int j = e[i];
    if (j == fa[u] || j == son[u])
      continue;
    dfs2(j, j);
  }
}

/*线段树的部分*/
void pushup(int u)
{
  tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void pushdown(int u)
{
  auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
  if (root.lazy)
  {
    left.sum += root.lazy * (left.r - left.l + 1);
    left.lazy += root.lazy;
    right.sum += root.lazy * (right.r - right.l + 1);
    right.lazy += root.lazy;
    root.lazy = 0;
  }
}

void build(int u, int l, int r)
{
  tr[u] = { l, r, nw[r], 0 };
  if (l == r)
    return;
  int mid = (l + r) >> 1;
  build(u << 1, l, mid);
  build(u << 1 | 1, mid + 1, r);
  pushup(u);
}

void update(int u, int l, int r, int k)
{
  if (l <= tr[u].l && r >= tr[u].r)
  {
    tr[u].lazy += k;
    tr[u].sum += k * (tr[u].r - tr[u].l + 1);
    return;
  }
  pushdown(u);
  int mid = (tr[u].l + tr[u].r) >> 1;
  if (l <= mid)
    update(u << 1, l, r, k);
  if (mid < r)
    update(u << 1 | 1, l, r, k);
  pushup(u);
}



int query(int u, int l, int r)
{
  if (l <= tr[u].l && r >= tr[u].r)
    return tr[u].sum;
  pushdown(u);
  int mid = (tr[u].l + tr[u].r) >> 1;
  int res = 0;
  if (l <= mid)
    res += query(u << 1, l, r);
  if (mid < r)
    res += query(u << 1 | 1, l, r);
  return res;
}



/*树链剖分部分*/

void update_path(int u, int v, int k)
{
  while (top[u] != top[v])
  {
    // u是深度高的重链
    if (dep[top[u]] < dep[top[v]])
      swap(u, v);
    update(1, id[top[u]], id[u], k);
    u = fa[top[u]];
  }
  if (dep[u] < dep[v])
    swap(u, v);
  update(1, id[v], id[u], k);
}

int query_path(int u, int v)
{
  int res = 0;
  while (top[u] != top[v])
  {
    if (dep[top[u]] < dep[top[v]])
      swap(u, v);
    res += query(1, id[top[u]], id[u]);
    u = fa[top[u]];
  }
  if (dep[u] < dep[v])
    swap(u, v);
  res += query(1, id[v], id[u]);
  return res;
}


void update_tree(int u, int k)
{
  update(1, id[u], id[u] + sz[u] - 1, k);
}

int query_tree(int u)
{
  return query(1, id[u], id[u] + sz[u] - 1);
}


void solve()
{
	mem(h,-1);
  cin >> n;
  for (int i = 1; i <= n; i++)
    cin >> a[i];
  for (int i = 1; i < n; i++)
  {
    int a, b;
    cin >> a >> b;
    add(a, b), add(b, a);
  }
  dfs1(1,-1,1);
  dfs2(1,1);
  build(1,1,n);
  cin >> m;
  while (m--)
  {
    int op, u, v, k;
    cin >> op;
    if (op == 1)
    {
      cin >> u >> v >> k;
      update_path(u, v, k);
    }
    else if (op == 2)
    {
      cin >> u >> k;
      update_tree(u, k);
    }
    else if (op == 3)
    {
      cin >> u >> v;
      print(query_path(u, v));
    }
    else
    {
      cin >> u;
      print(query_tree(u));
    }
  }
}

signed main()
{
  std::ios::sync_with_stdio(false);
  cin.tie(0), cout.tie(0);
  // clock_t start_time = clock();
  int __ = 1;
  //    cin>>__;
  //    init();
  while (__--)
    solve();
  // clock_t end_time = clock();
  // cerr << "Running time is: " << ( double )(end_time - start_time) / CLOCKS_PER_SEC * 1000 << "ms" << endl;
  return 0;
}

作者:Avalon·Demerzel
更多内容见专栏:图论与数据结构

  • 6
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值