树链剖分
对于一颗树来说,维护树上任意一段路径信息比较困难,但是我们可以把树拆成多个链,这种处理和维护树上路径信息的方法叫树链剖分(树剖、链剖)。
树链剖分有多种形式,比如重链剖分、长链剖分和用于 Link/cut Tree 的剖分(有时被称作“实链剖分”),大多数情况下(没有特别说明时),“树链剖分”都指“重链剖分”。
重链剖分可以将树上的任意一条路径划分成不超过 O ( log n ) O(\log n) O(logn)条连续的链,每条链上的点深度互不相同(即是自底向上的一条链,链上所有点的 LCA 为链的一个端点)。
重链剖分还能保证划分出的每条链上的节点 DFS 序连续,因此可以方便地用一些维护序列的数据结构(如线段树)来维护树上路径的信息。
如:
- 求任意路径上的最大最小值,求和等等
- 修改任意路径上的所有点。
- 求LCA
等等路径&区间操作。
树链剖分的核心思想就是破树为链。
本文着重记录一下重链剖分。
概念
我们给出一些定义:
树的重量:一颗子树的所有节点数。
重子节点:根节点 r r r,有 n n n个子节点,分别为 s 1 , s 2 , … , s n s_{1},s_{2},\ldots,s_{n} s1,s2,…,sn,其中 r r r的重子节点有且仅有一个,这个子节点的重量比其他任何子节点的重量都大。特别的,如果 r r r为叶子节点,那么 r r r没有重子节点。
轻子节点:除了重子节点外其他节点都是轻子节点。
重边、轻边:从 r r r到其重子节点的边叫重边,其他边叫轻边。
重链:由连续干个重边组成的一条链叫做重链。特别的,一颗树中会出现一些叶子节点,没有与他们直接相连的重边,那他们就自己单独组成一个长度为 1 1 1 的重链。这样,一棵树就为被解剖成若干条重链。
下面给出节点的定义:
- f a ( x ) fa(x) fa(x)表示 x x x的父节点
- d e p ( x ) dep(x) dep(x)表示 x x x的深度
- s i z ( x ) siz(x) siz(x)表示 x x x的重量
- s o n ( x ) son(x) son(x)表示 x x x的重儿子
- t o p ( x ) top(x) top(x)表示 x x x所在的重链中节点深度最小的节点,即重链的头结点
- d f n ( x ) dfn(x) dfn(x),表示 x x x在DFS过程中的DFS序
- r n k ( d f s ) rnk(dfs) rnk(dfs),表示DFS序 d f s dfs dfs所对应的节点,有 x = r n k ( d f n ( x ) ) x = rnk(dfn(x)) x=rnk(dfn(x))
构造过程
构造一颗树的树链需要两个DFS过程。
定义数据结构:
struct Edge
{
int to;
int nxt;
} e[1000005];
int head[500005];
int tot = 0;
int fa[500005];
int siz[500005];
int dep[500005];
int son[500005];
int top[500005];
int dfn[500005];
int rnk[500005];
int ti = 0;
inline void add(int u, int v)
{
tot++;
e[tot].to = v;
e[tot].nxt = head[u];
head[u] = tot;
}
第一个DFS负责计算节点的父节点,深度,节点重量,重子节点。
void dfs1(int f, int idx)
{
fa[idx] = f;
dep[idx] = dep[f] + 1;
int si = 0;
for (int ne = head[idx]; ne != 0; ne = e[ne].nxt)
{
int su = e[ne].to;
if (su == f)
continue;
dfs1(idx, su);
si += siz[su];
if (siz[su] > siz[son[idx]])
son[idx] = su;
}
siz[idx] = 1 + si;
}
第二个DFS记录链头,DFS序,DFS序的逆映射。
注意,这里和普通DFS有一点不同,如果一个节点有重子节点,应该优先访问重子节点,这样才能保证树链上的节点DFS序莲须。
void dfs2(int f, int idx)
{
ti++;
dfn[idx] = ti;
rnk[ti] = idx;
int bs = son[idx];
if (bs != 0)
{
top[bs] = top[idx];
dfs2(idx, bs);
}
for (int ne = head[idx]; ne != 0; ne = e[ne].nxt)
{
int v = e[ne].to;
if (v == f || v == son[idx])
continue;
top[v] = v;
dfs2(idx, v);
}
}
分别调用:dfs1和dfs2就可以得到一颗树的树剖了。
求LCA
树链剖分可以用来求LCA,过程如下:
int LCA(int u, int v)
{
while (top[u] != top[v]) // 如果两个节点不在同一条树链上,那么就需要跳跃
{
// 跳跃从树链头节点深度较大的开始跳
if (dep[top[u]] > dep[top[v]])
u = fa[top[u]];
else
v = fa[top[v]];
}
// 如果处在同一个树链上,返回深度最小的那个节点,那个节点就是LCA
if (dep[u] < dep[v])
return u;
else
return v;
}
维护路程和,子树和
我们把dfn作为区间的下标,可以使用线段树或者树状数组维护区间和。
类比于LCA,我们只需要在求LCA的过程中,边走边加上路径和即可。
对于维护子树和,我们可以通过记录子树的最后的dfn,子树的dfn序列必定连续,维护区间和即可。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define FR freopen("in.txt", "r", stdin)
#define LT(x) (x << 1)
#define RT(x) ((x << 1) + 1)
struct TreeNode
{
int l;
int r;
ll sum;
ll lazy;
};
struct SegT
{
TreeNode t[400005];
ll mod;
void init(ll *init, int n, ll p)
{
mod = p;
buildTree(init, 1, 1, n + 1);
}
void buildTree(ll *init, int i, int l, int r)
{
t[i].l = l;
t[i].r = r;
t[i].lazy = 0;
if (l == r - 1)
{
t[i].sum = init[l] % mod;
}
else
{
int mid = (l + r) >> 1;
buildTree(init, LT(i), l, mid);
buildTree(init, RT(i), mid, r);
t[i].sum = (t[LT(i)].sum + t[RT(i)].sum) % mod;
}
}
void pushdown(int i)
{
t[LT(i)].sum = (t[LT(i)].sum + ((t[LT(i)].r - t[LT(i)].l) * t[i].lazy) % mod) % mod;
t[RT(i)].sum = (t[RT(i)].sum + ((t[RT(i)].r - t[RT(i)].l) * t[i].lazy % mod)) % mod;
t[LT(i)].lazy = (t[LT(i)].lazy + t[i].lazy) % mod;
t[RT(i)].lazy = (t[RT(i)].lazy + t[i].lazy) % mod;
t[i].lazy = 0;
}
void add(int i, int l, int r, ll val)
{
if (t[i].l >= l && t[i].r <= r)
{
t[i].sum = (t[i].sum + ((t[i].r - t[i].l) * val) % mod) % mod;
t[i].lazy = (t[i].lazy + val) % mod;
}
else
{
pushdown(i);
if (t[LT(i)].r > l)
{
add(LT(i), l, r, val);
}
if (t[RT(i)].l < r)
{
add(RT(i), l, r, val);
}
t[i].sum = (t[LT(i)].sum + t[RT(i)].sum) % mod;
}
}
ll query(int i, int l, int r)
{
if (t[i].l >= l && t[i].r <= r)
{
return t[i].sum % mod;
}
else
{
pushdown(i);
ll ans = 0;
if (t[LT(i)].r > l)
{
ans = (ans + query(LT(i), l, r)) % mod;
}
if (t[RT(i)].l < r)
{
ans = (ans + query(RT(i), l, r)) % mod;
}
return ans;
}
}
} ST;
struct Edge
{
int to;
int nxt;
} e[200005];
int head[100005];
int tot = 0;
int fa[100005];
int siz[100005];
int dep[100005];
int son[100005];
int top[100005];
int bottom[100005];
int dfn[100005];
int rnk[100005];
int ti = 0;
ll init[100005];
ll init1[100005];
inline void add(int u, int v)
{
tot++;
e[tot].to = v;
e[tot].nxt = head[u];
head[u] = tot;
}
void dfs1(int f, int idx)
{
fa[idx] = f;
dep[idx] = dep[f] + 1;
int si = 0;
for (int ne = head[idx]; ne != 0; ne = e[ne].nxt)
{
int su = e[ne].to;
if (su == f)
continue;
dfs1(idx, su);
si += siz[su];
if (siz[su] > siz[son[idx]])
son[idx] = su;
}
siz[idx] = 1 + si;
}
void dfs2(int f, int idx)
{
ti++;
dfn[idx] = ti;
rnk[ti] = idx;
bottom[idx] = ti;
int bs = son[idx];
if (bs != 0)
{
top[bs] = top[idx];
dfs2(idx, bs);
bottom[idx] = max(bottom[idx], bottom[bs]);
}
for (int ne = head[idx]; ne != 0; ne = e[ne].nxt)
{
int v = e[ne].to;
if (v == f || v == son[idx])
continue;
top[v] = v;
dfs2(idx, v);
bottom[idx] = max(bottom[idx], bottom[v]);
}
}
void LCAadd(SegT &st, int u, int v, ll val)
{
while (top[u] != top[v]) // 如果两个节点不在同一条树链上,那么就需要跳跃
{
// 跳跃从树链头节点深度较大的开始跳
if (dep[top[u]] < dep[top[v]])
{
swap(u, v);
}
st.add(1, dfn[top[u]], dfn[u] + 1, val);
u = fa[top[u]];
}
if (dfn[u] > dfn[v])
swap(u, v);
st.add(1, dfn[u], dfn[v] + 1, val);
}
ll LCAquery(SegT &st, int u, int v, ll mod)
{
ll ans = 0;
while (top[u] != top[v]) // 如果两个节点不在同一条树链上,那么就需要跳跃
{
// 跳跃从树链头节点深度较大的开始跳
if (dep[top[u]] < dep[top[v]])
{
swap(u, v);
}
ans = (ans + st.query(1, dfn[top[u]], dfn[u] + 1)) % mod;
u = fa[top[u]];
}
if (dfn[u] > dfn[v])
swap(u, v);
ans = (ans + st.query(1, dfn[u], dfn[v] + 1)) % mod;
return ans;
}
void TREEadd(SegT &st, int u, ll val)
{
int v = bottom[u];
st.add(1, dfn[u], v + 1, val);
}
ll TREEquery(SegT &st, int u)
{
int v = bottom[u];
return st.query(1, dfn[u], v + 1);
}
int main()
{
int n, m, r;
ll p;
cin >> n >> m >> r >> p;
for (int i = 1; i <= n; i++)
{
cin >> init[i];
}
for (int i = 0; i < n - 1; i++)
{
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
dep[r] = 1;
dfs1(r, r);
dfs2(r, r);
for (int i = 1; i <= n; i++)
{
init1[dfn[i]] = init[i];
}
ST.init(init1, n, p);
while (m--)
{
int op;
cin >> op;
int u, v, x;
ll val;
switch (op)
{
case 1:
cin >> u >> v >> val;
val %= p;
LCAadd(ST, u, v, val);
break;
case 2:
cin >> u >> v;
cout << (LCAquery(ST, u, v, p) % p) << endl;
break;
case 3:
cin >> x >> val;
val %= p;
TREEadd(ST, x, val);
break;
case 4:
cin >> x;
cout << (TREEquery(ST, x) % p) << endl;
break;
default:
break;
}
}
return 0;
}
例题
判断树上两个路径是否相交,有如下定理:
树上两个路径相交,当且仅当存在一条路径上的起点和终点的LCA节点是另外一条路径上的一个节点。
因此此题为求LCA问题,并且判断点是否在路径上,树链剖分即可。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
#define FR freopen("in.txt", "r", stdin)
#define FW freopen("out1.txt", "w", stdout)
struct Edge
{
int to;
int nxt;
} e[200005];
int head[100005];
int tot = 0;
int n, q;
int fa[100005];
int wi[100005];
int wson[100005];
int dep[100005];
int dfn[100005];
int fn = 0;
int top[100005];
void add(int u, int v)
{
tot++;
e[tot].to = v;
e[tot].nxt = head[u];
head[u] = tot;
}
void dfs1(int u, int root)
{
dep[u] = dep[root] + 1;
fa[u] = root;
wi[u] = 1;
for (int ne = head[u]; ne; ne = e[ne].nxt)
{
int to = e[ne].to;
if (to == root)
continue;
dfs1(to, u);
wi[u] += wi[to];
if (wi[wson[u]] < wi[to])
wson[u] = to;
}
}
void dfs2(int u, int root)
{
fn++;
dfn[u] = fn;
if (wson[u] != 0)
{
top[wson[u]] = top[u];
dfs2(wson[u], u);
}
for (int ne = head[u]; ne; ne = e[ne].nxt)
{
int to = e[ne].to;
if (to == root || to == wson[u])
continue;
top[to] = to;
dfs2(to, u);
}
}
int LCA(int a, int b)
{
while (top[a] != top[b])
{
if (dep[top[a]] > dep[top[b]])
{
a = fa[top[a]];
}
else
{
b = fa[top[b]];
}
}
if (dep[a] < dep[b])
{
return a;
}
else
{
return b;
}
}
bool LCAC(int a, int b, int c)
{
while (top[a] != top[b])
{
if (dep[top[a]] > dep[top[b]])
{
if (dfn[c] <= dfn[a] && dfn[c] >= dfn[top[a]])
{
return true;
}
a = fa[top[a]];
}
else
{
if (dfn[c] <= dfn[b] && dfn[c] >= dfn[top[b]])
{
return true;
}
b = fa[top[b]];
}
}
if (dep[a] > dep[b])
{
if (dfn[c] <= dfn[a] && dfn[c] >= dfn[b])
{
return true;
}
}
else
{
if (dfn[c] <= dfn[b] && dfn[c] >= dfn[a])
{
return true;
}
}
return false;
}
void loop(int u)
{
printf("%d\n", u);
while (fa[u] != 1)
{
printf("%d\n", u = fa[u]);
}
printf("1\n");
}
int main()
{
scanf("%d %d", &n, &q);
for (int i = 0; i < n - 1; i++)
{
int u, v;
scanf("%d %d", &u, &v);
add(u, v);
add(v, u);
}
dfs1(1, 1);
top[1] = 1;
dfs2(1, 1);
while (q--)
{
int a, b, c, d;
scanf("%d %d %d %d", &a, &b, &c, &d);
int lcaA = LCA(a, b);
int lcaB = LCA(c, d);
printf("%c\n", LCAC(a, b, lcaB) || LCAC(c, d, lcaA) ? 'Y' : 'N');
}
return 0;
}