什么是树链剖分
树链剖分是把一棵树分割成若干条链,以便于维护信息的一种方法,其中最常用的是重链剖分,所以一般提到树链剖分或树剖都是指重链剖分。除此之外还有长链剖分和实链剖分等。
我们定义树上一个节点的子节点中子树最大的一个为它的重子节点,其余的为轻子节点。
一个节点连向其重子节点的边称为重边,连向轻子节点的边则为轻边。
如果把根节点看作轻节点的,那么从每个轻节点出发,不断向下走重边,都对应了一条链,于是我们把树剖分成了 l l l 条链,其中 l l l 是轻节点的数量。
详见下图:
图中有 2 2 2 条重链, 4 4 4 条轻链。
重链剖分有一个重要的性质:对于节点数为 n n n 的树,从任意节点向上走到根节点,经过的轻边数量不超过 log n \log n logn。
这是因为,如果一个节点连向父节点的边是轻边,就必然存在子树不小于它的兄弟节点,那么父节点对应子树的大小一定超过该节点的两倍。
所以每经过一条轻边,子树大小就翻倍,所以最多只能经过 log n \log n logn 条。
以上是时间复杂度的证明。
树链剖分的实现
我们通过两次 dfs
来进行重链剖分。
第一趟 dfs
,先得到每个节点的
f
a
fa
fa(父节点)、
s
i
z
siz
siz(子树大小)、
d
e
p
dep
dep(深度)、
h
s
o
n
hson
hson(重子节点):
void dfs1(int u, int d = 1)
{
int mx = 0;
siz[u] = 1;
dep[u] = d;
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (!dep[v])
{
dfs1(v, d + 1);
fa[v] = u;
siz[u] += siz[v];
if (mx < siz[v])
{
hson[u] = v;
mx = siz[v];
}
}
}
return;
}
第二趟 dfs
,得到每个节点的 top
(链头,即所在的重链中深度最小的那个节点)和每个点的 dfs
序:
void dfs2(int u, int topf)
{
dfsn[u] = ++cnt;
top[u] = topf;
if (hson[u])
{
dfs2(hson[u], topf);
}
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (!top[v])
{
dfs2(v, v);
}
}
return;
}
为什么要将一棵树看成多条链呢 ?
其实这样可以作很多的区间操作,
因为我们在求链头的时候顺便求了 dfs
序。
所以对于一条链,其上的所有节点的 dfs
序都是连续的,所以所有链其实对应着一个完整的区间,我们通常用线段树或树状数组来维护这些链。
求最近公共祖先
树链剖分对于最近公共祖先的求法其实和倍增很相似,倍增每次往上跳 2 2 2 的 k k k 次方次,而树链剖分则每次往上跳一条链。
时间复杂度上面已经证明了,为 O ( log n ) O(\log n) O(logn)。
假如要求 u u u 和 v v v 的最近公共祖先
每次选择链头深度大的点跳,直到将两个点跳到同一条链上,最近公共祖先就是两个点中深度较小的那个点。
inline int lca(int u, int v)
{
while (top[u] != top[v])
{
if (dep[top[u]] > dep[top[v]])
{
u = fa[top[u]];
}
else
{
v = fa[top[v]];
}
}
return dep[u] > dep[v] ? v : u;
}
维护两点间的简单路径
其实和求最近公共祖先的方法类似。
每跳一条链时维护链上的信息就行了。
void update_path(int u, int v, int z)
{
while (top[u] != top[v])
{
if (dep[top[u]] > dep[top[v]])
{
update(1, 1, n, dfsn[top[u]], dfsn[u], z);
u = fa[top[u]];
}
else
{
update(1, 1, n, dfsn[top[v]], dfsn[v], z);
v = fa[top[v]];
}
}
if (dep[u] > dep[v])
{
update(1, 1, n, dfsn[v], dfsn[u], z);
}
else
{
update(1, 1, n, dfsn[u], dfsn[v], z);
}
}
int query_path(int u, int v)
{
int res = 0;
while (top[u] != top[v])
{
if (dep[top[u]] > dep[top[v]])
{
res += query(1, 1, n, dfsn[top[u]], dfsn[u]);
u = fa[top[u]];
}
else
{
res += query(1, 1, n, dfsn[top[v]], dfsn[v]);
v = fa[top[v]];
}
}
if (dep[u] > dep[v])
{
res += query(1, 1, n, dfsn[v], dfsn[u]);
}
else
{
res += query(1, 1, n, dfsn[u], dfsn[v]);
}
return res;
}
维护某一点的子树
这次比较简单。
因为一个节点的子树上的所有节点的 dfs
序肯定是连续的,简单维护就行了。
区间为 [ d f n u , d f n u + s i z u − 1 ] [dfn_u,dfn_u+siz_u-1] [dfnu,dfnu+sizu−1]。
void update_subtree(int u, int z)
{
return update(1, 1, n, dfsn[u], dfsn[u] + siz[u] - 1, z);
}
int query_subtree(int u)
{
return query(1, 1, n, dfsn[u], dfsn[u] + siz[u] - 1);
}
模板
#include <bits/stdc++.h>
#define int long long
#define _ 500005
#define ls(a) a << 1
#define rs(a) a << 1 | 1
using namespace std;
int n, m, r, p;
int op, x, y, z;
int a[_], b[_];
int head[_], to[_], nxt[_], tot;
int tree[_], lazy[_];
int fa[_], dep[_], siz[_], top[_], hson[_], dfsn[_], cnt;
//int mxdfsn[_];
void push_up(int o)
{
tree[o] = tree[ls(o)] + tree[rs(o)];
}
void push_down(int o, int l, int r)
{
if (lazy[o])
{
int mid = (l + r) >> 1;
tree[ls(o)] += lazy[o] * (mid - l + 1);
tree[rs(o)] += lazy[o] * (r - (mid + 1) + 1);
lazy[ls(o)] += lazy[o];
lazy[rs(o)] += lazy[o];
lazy[o] = 0;
}
}
void build(int o, int l, int r)
{
if (l == r)
{
tree[o] = a[l];
return;
}
int mid = (l + r) >> 1;
build(ls(o), l, mid);
build(rs(o), mid + 1, r);
push_up(o);
}
void update(int o, int l, int r, int L, int R, int val)
{
if (L <= l && r <= R)
{
tree[o] += (r - l + 1) * val;
lazy[o] += val;
return;
}
int mid = (l + r) >> 1;
push_down(o, l, r);
if (L <= mid)
{
update(ls(o), l, mid, L, R, val);
}
if (R > mid)
{
update(rs(o), mid + 1, r, L, R, val);
}
push_up(o);
}
int query(int o, int l, int r, int L, int R)
{
if (L <= l && r <= R)
{
return tree[o];
}
int mid = (l + r) >> 1;
int res = 0;
push_down(o, l, r);
if (L <= mid)
{
res += query(ls(o), l, mid, L, R);
}
if (R > mid)
{
res += query(rs(o), mid + 1, r, L, R);
}
return res;
}
inline void add(int a, int b)
{
to[++tot] = b;
nxt[tot] = head[a];
head[a] = tot;
}
void dfs1(int u, int d = 1)
{
int mx = 0;
siz[u] = 1;
dep[u] = d;
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (!dep[v])
{
dfs1(v, d + 1);
fa[v] = u;
siz[u] += siz[v];
if (mx < siz[v])
{
hson[u] = v;
mx = siz[v];
}
}
}
return;
}
void dfs2(int u, int topf)
{
//mxdfsn[u] =
dfsn[u] = ++cnt;
top[u] = topf;
if (hson[u])
{
dfs2(hson[u], topf);
//mxdfsn[u] = max(mxdfsn[u], mxdfsn[hson[u]]);
}
for (int i = head[u]; i; i = nxt[i])
{
int v = to[i];
if (!top[v])
{
dfs2(v, v);
//mxdfsn[u] = max(mxdfsn[u], mxdfsn[v]);
}
}
return;
}
inline int lca(int u, int v)
{
while (top[u] != top[v])
{
if (dep[top[u]] > dep[top[v]])
{
u = fa[top[u]];
}
else
{
v = fa[top[v]];
}
}
return dep[u] > dep[v] ? v : u;
}
void update_path(int u, int v, int z)
{
while (top[u] != top[v])
{
if (dep[top[u]] > dep[top[v]])
{
update(1, 1, n, dfsn[top[u]], dfsn[u], z);
u = fa[top[u]];
}
else
{
update(1, 1, n, dfsn[top[v]], dfsn[v], z);
v = fa[top[v]];
}
}
if (dep[u] > dep[v])
{
update(1, 1, n, dfsn[v], dfsn[u], z);
}
else
{
update(1, 1, n, dfsn[u], dfsn[v], z);
}
}
int query_path(int u, int v)
{
int res = 0;
while (top[u] != top[v])
{
if (dep[top[u]] > dep[top[v]])
{
res += query(1, 1, n, dfsn[top[u]], dfsn[u]);
u = fa[top[u]];
}
else
{
res += query(1, 1, n, dfsn[top[v]], dfsn[v]);
v = fa[top[v]];
}
}
if (dep[u] > dep[v])
{
res += query(1, 1, n, dfsn[v], dfsn[u]);
}
else
{
res += query(1, 1, n, dfsn[u], dfsn[v]);
}
return res;
}
void update_subtree(int u, int z)
{
return update(1, 1, n, dfsn[u], dfsn[u] + siz[u] - 1, z); //update(1, 1, n, dfsn[u], mxdfsn[u], z);
}
int query_subtree(int u)
{
return query(1, 1, n, dfsn[u], dfsn[u] + siz[u] - 1); //query(1, 1, n, dfsn[u], mxdfsn[u]);
}
signed main()
{
scanf("%lld%lld%lld%lld", &n, &m, &r, &p);
for (int i = 1; i <= n; ++i)
{
scanf("%lld", &b[i]);
}
for (int i = 1; i < n; i++)
{
scanf("%lld%lld", &x, &y);
add(x, y);
add(y, x);
}
dfs1(r);
dfs2(r, r);
for (int i = 1; i <= n; ++i)
{
a[dfsn[i]] = b[i];
}
build(1, 1, n);
for (int i = 1; i <= m; ++i)
{
scanf("%lld", &op);
if (op == 1)
{
scanf("%lld%lld%lld", &x, &y, &z);
update_path(x, y, z);
}
else if (op == 2)
{
scanf("%lld%lld", &x, &y);
printf("%lld\n", query_path(x, y) % p);
}
else if (op == 3)
{
scanf("%lld%lld", &x, &z);
update_subtree(x, z);
}
else
{
scanf("%lld", &x);
printf("%lld\n", query_subtree(x) % p);
}
}
return 0;
}