前言
树链剖分将整棵树剖分为若干条链,使它组合成线性结构,然后用其他的数据结构维护信息。
本文讨论常见的重链剖分
各定义来自 OI Wiki 树链剖分
重子节点 子节点中子树最大的子结点
轻子节点 表示剩余的所有子结点
重边 从这个结点到重子节点的边
轻边 到其他轻子节点的边
若干条首尾衔接的重边构成 重链
[fa(x)] 表示节点 [x] 在树上的父亲。
[dep(x)] 表示节点 [x] 在树上的深度。
[siz(x)] 表示节点 [x] 的子树的节点个数。
[son(x)] 表示节点 [x] 的 重儿子。
[top(x)] 表示节点 [x] 所在 重链 的顶部节点(深度最小)。
[dfn(x)] 表示节点 [x] 的 DFS 序,也是其在线段树中的编号。
[rnk(x)] 表示 DFS 序所对应的节点编号,有 [rnk(dfn(x))=x] 。
树剖模板
题目:P3384 【模板】重链剖分/树链剖分
参考:题解 P3384 【【模板】树链剖分】
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5, T = 4e5 + 5;
#define lc (x << 1)
#define rc (x << 1 | 1)
vector<int> e[N];
void addEdge(int x, int y)
{
e[x].push_back(y);
e[y].push_back(x);
}
int sz[N], top[N], dfn[N], fa[N], dep[N], son[N];
int cnt = 0, a[N], b[N];//a[i]是dfs序为i的节点的值,b[i]是节点号为i的节点的值
int n, m, rt, P;
void dfs1(int x, int f)//预处理深度,父亲,重儿子,子树大小
{
fa[x] = f;
dep[x] = dep[f] + 1;
son[x] = 0;
sz[x] = 1;
for (auto y : e[x])
{
if (y == f) continue;
dfs1(y, x);
if (sz[y] > sz[son[x]]) son[x] = y;
sz[x] += sz[y];
}
}
//注意tp的取值...
void dfs2(int x, int tp)//预处理出top,dfs序和按dfs序变换后的数组
{
top[x] = tp;
dfn[x] = ++cnt;
a[cnt] = b[x];
if (!son[x]) return;
dfs2(son[x], tp);//先走重儿子,这样一条重链上的dfs序才会连续
for (auto y : e[x])
{
if (dfn[y]) continue;//如果走过(包括父亲和重儿子
dfs2(y, y);
}
}
struct Tree
{
int l, r, sz;
ll add, sum;
}t[T];
void pushup(int x)
{
t[x].sum = (t[lc].sum + t[rc].sum) % P;
}
void pushdown(int x)
{
if (!t[x].add) return;
t[lc].add = (t[lc].add + t[x].add) % P;
t[rc].add = (t[rc].add + t[x].add) % P;
t[lc].sum = (t[lc].sum + t[lc].sz * t[x].add % P) % P;
t[rc].sum = (t[rc].sum + t[rc].sz * t[x].add % P) % P;
t[x].add = 0;
}
void build(int x, int l, int r)
{
t[x].l = l, t[x].r = r;
t[x].sz = r - l + 1;
t[x].add = 0;
if (l == r)
{
t[x].sum = a[l] % P;
return;
}
int mid = l + r >> 1;
build(lc, l, mid);
build(rc, mid + 1, r);
pushup(x);
}
void add(int x, int l, int r, int val)
{
if (l <= t[x].l && r >= t[x].r)
{
t[x].add = (t[x].add + val) % P;
t[x].sum = (t[x].sum + t[x].sz * val % P) % P;
return;
}
pushdown(x);
int mid = t[x].l + t[x].r >> 1;
if (l <= mid) add(lc, l, r, val);
if (r > mid) add(rc, l, r, val);
pushup(x);
}
ll query(int x, int l, int r)
{
if (l <= t[x].l && r >= t[x].r)
{
return t[x].sum;
}
pushdown(x);
ll res = 0;
int mid = t[x].l + t[x].r >> 1;
if (l <= mid) res = query(lc, l, r);
if (r > mid) res = (res + query(rc, l, r)) % P;
return res;
}
void addTree(int x, int y, int val)
{
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x, y);//比较xy的链头哪个在上面
add(1, dfn[top[x]], dfn[x], val);//给整条链修改
x = fa[top[x]];//在下面的往上跳出这条链
}
//跳到同一链上了
if (dep[x] < dep[y]) swap(x, y);
add(1, dfn[y], dfn[x], val);
}
ll queryTree(int x, int y)
{
ll res = 0;
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x, y);
res = (res + query(1, dfn[top[x]], dfn[x])) % P;
x = fa[top[x]];
}
if (dep[x] < dep[y]) swap(x, y);
res = (res + query(1, dfn[y], dfn[x])) % P;
return res;
}
int main()
{
cin >> n >> m >> rt >> P;
for (int i = 1; i <= n; i++)
cin >> b[i];
for (int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
addEdge(x, y);
}
dfs1(rt, 0);
dfs2(rt, rt);
build(1, 1, n);
for (int i = 0; i < m; i++)
{
int k, x, y, z;
cin >> k;
if (k == 1)
{
cin >> x >> y >> z;
addTree(x, y, z % P);
}
else if (k == 2)
{
cin >> x >> y;
cout << queryTree(x, y) << endl;
}
else if (k == 3)
{
cin >> x >> z;
add(1, dfn[x], dfn[x] + sz[x] - 1, z % P);
//这个x的子树的表示方法我也要学学...好厉害
}
else
{
cin >> x;
cout << query(1, dfn[x], dfn[x] + sz[x] - 1) << endl;
}
}
return 0;
}
树剖求lca
实际上就是借用树剖实现让点往上跳的功能,当两个点跳到同一条链上时,在上面的那个点就是lca了
(结合代码食用)
感性(瞎想)理解一下:
想象一棵以lca(x,y)为根的子树,x和y就在这棵树的某条链上。假设现在x、y已经跳了若干步,都在以lca或者lca的儿子节点为top的链上。那么现在有两种情况:
- x在(lca所在的)重链上:容易得出
top[x]
就是top[lca]
,top[y]
是lca的一个轻儿子,那么有dep[top[x]] < dep[top[y]]
,y向上跳到lca,此时xy同链,y在上 - x和y都不在重链上:那么
top[x]
top[y]
都是lca的某个子节点,深度一样,假设x向上跳到lca,此时xy还不同链,回到第一种情况
任意情况都可以跳成上面的情况(应该是吧我感觉没问题但是脑子好乱
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
typedef long long ll;
const int N = 500005;
int dep[N], fa[N], top[N], son[N], sz[N];
vector<int> e[N];
void addEdge(int x, int y)
{
e[x].push_back(y);
e[y].push_back(x);
}
void dfs1(int x, int f)
{
fa[x] = f;
dep[x] = dep[f] + 1;
sz[x] = 1;
for (auto y : e[x])
{
if (y == f) continue;
dfs1(y, x);
if (sz[y] > sz[son[x]]) son[x] = y;
sz[x] += sz[y];
}
}
void dfs2(int x, int tp)
{
top[x] = tp;
if (!son[x]) return;
dfs2(son[x], tp);
for (auto y : e[x])
{
if (y == fa[x] || y == son[x]) continue;
dfs2(y, y);
}
}
int lca(int x, int y)
{
while (top[x] != top[y])
{
if (dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
return dep[x] < dep[y] ? x : y;
}
int main()
{
int n, m, rt;
cin >> n >> m >> rt;
for (int i = 1; i < n; i++)
{
int x, y;
cin >> x >> y;
addEdge(x, y);
}
dfs1(rt, 0);
dfs2(rt, rt);
for (int i = 0; i < m; i++)
{
int a, b;
cin >> a >> b;
cout << lca(a, b) << endl;
}
return 0;
}
好了先学到这里(颤颤巍巍断断续续啃了两星期