树链剖分
-
问题引入:
操作 1 1 1
- 将树从 x x x 到 y y y 结点最短路径上所有节点的值都加上 z z z
操作 2 2 2
- 求树从 x x x 到 y y y 结点最短路径上所有节点的值之和
两个操作分开很好解决,第一种用树上差分,第二种用 l c a lca lca
-
引入新方法:树剖是通过轻重边剖分将树分割成多条链,然后利用数据结构来维护这些链
- 重儿子:父亲节点的所有儿子中子树结点数目最多( s i z e size size 最大)的结点;
- 轻儿子:父亲节点中除了重儿子以外的儿子;
- 重边:父亲结点和重儿子连成的边;
- 轻边:父亲节点和轻儿子连成的边;
- 重链:由多条重边连接而成的路径;
- 轻链:由多条轻边连接而成的路径;
f [ u ] u 的 父 亲 节 点 d [ u ] u 的 深 度 s i z e [ u ] 以 u 为 根 的 子 树 节 点 数 s o n [ u ] u 的 重 儿 子 r k [ u ] 保 存 当 前 d f s 序 号 在 树 中 对 应 的 节 点 t o p [ u ] 保 存 当 前 点 所 在 链 上 的 起 点 i d [ u ] 保 存 树 中 每 个 点 剖 分 之 后 的 新 编 号 ( d f s 序 ) \begin{aligned} & f[u] \qquad u的父亲节点 \\ & d[u] \qquad u的深度 \\ & size[u] \qquad 以u为根的子树节点数 \\ & son[u] \qquad u的重儿子 \\ & rk[u] \qquad 保存当前 dfs 序号在树中对应的节点 \\ & top[u] \qquad 保存当前点所在链上的起点 \\ & id[u] \qquad 保存树中每个点剖分之后的新编号(dfs序) \end{aligned} f[u]u的父亲节点d[u]u的深度size[u]以u为根的子树节点数son[u]u的重儿子rk[u]保存当前dfs序号在树中对应的节点top[u]保存当前点所在链上的起点id[u]保存树中每个点剖分之后的新编号(dfs序)
先由 d f s dfs dfs 求出 f [ ] , s o n [ ] , d [ ] , s i z e [ ] f[],son[],d[],size[] f[],son[],d[],size[] 数组
void dfs1(int x, int fa, int dep) { f[x] = fa; size[x] = 1; depth[x] = dep; for (int i = head[x]; i; i = Next[i]) { int y = ver[i]; if (y == fa) continue; dfs(y, x, dep + 1); size[x] += size[y]; if (size[y] > size[son[x]]) son[x] = y; } }
再 d f s dfs dfs 求 t o p [ ] , i d [ ] , r k [ ] top[],id[],rk[] top[],id[],rk[] 数组
void dfs2(int x, int tp) { top[x] = tp, id[x] = ++cnt, rk[cnt] = x; if (son[x]) dfs2(son[x], tp); // 先搜索重儿子的那条边,保证dfs序连续 for (int v, i = head[x]; i; i = Next[i]) if ((v = ver[i]) != fa[x] && v != son[x]) dfs2(v, v); }
最后再由两个要修改的点向重儿子链上修改,可以证明树链剖分的复杂度是 O ( l o g ( n ) ) O(log(n)) O(log(n)) ,里面再套一个线段树,所以整体复杂度是 O ( l o g ( n ) ) 2 O(log(n))^2 O(log(n))2
inline ll sum(int x, int y) { ll ret = 0; while (top[x] != top[y]) { if (d[top[x]] < d[top[y]]) //选择深度较深的那一个点修改 swap(x, y); (ret += query(1, id[top[x]], id[x])) %= mod; x = fa[top[x]]; } if (id[x] > id[y]) swap(x, y); return (ret + query(1, id[x], id[y])) % mod; }
最后洛谷模板题
最后要注意的是是对于线段树我们建的是 d f s dfs dfs 序的树,所以赋初值的时候要按 d f s dfs dfs 序的下标赋初值。
最后代码:
const int N = 1e5 + 5; int head[N], ver[N << 1], Next[N << 1]; int size[N], d[N], son[N], fa[N], cnt, top[N], rk[N], id[N]; ll mod, tot, n, m, r; ll a[N]; void add(int x, int y) { ver[++tot] = y; Next[tot] = head[x]; head[x] = tot; } struct stree { int l; int r; ll sum; ll add; } tree[N << 2]; void pushup(int rt) { tree[rt].sum = (tree[rt << 1].sum + tree[rt << 1 | 1].sum) % mod; } void pushdown(int rt) { if (tree[rt].add) { tree[rt << 1].sum += tree[rt].add * (tree[rt << 1].r - tree[rt << 1].l + 1); tree[rt << 1].sum %= mod; tree[rt << 1 | 1].sum += tree[rt].add * (tree[rt << 1 | 1].r - tree[rt << 1 | 1].l + 1); tree[rt << 1 | 1].sum %= mod; tree[rt << 1].add += tree[rt].add; tree[rt << 1].add %= mod; tree[rt << 1 | 1].add += tree[rt].add; tree[rt << 1 | 1].add %= mod; tree[rt].add = 0 ; } } void build(int rt,int l,int r) { tree[rt].l = l; tree[rt].r = r; if(l == r) { tree[rt].sum = a[rk[l]]; return; } int mid = (l+r)>>1; pushdown(rt); build(rt<<1,l,mid); build(rt<<1|1,mid+1,r); pushup(rt); } void update(int rt, int L, int R, int val) { int l = tree[rt].l; int r = tree[rt].r; if (L <= l && r <= R) { tree[rt].sum += (ll)val * (tree[rt].r - tree[rt].l + 1); tree[rt].sum %= mod; tree[rt].add += val; tree[rt].add %= mod; return ; } int mid = (l + r) >> 1; pushdown(rt); if (L <= mid) update(rt << 1, L, R, val); if (R > mid) update(rt << 1 | 1, L, R, val); pushup(rt); } ll query(int rt, int L, int R) { int l = tree[rt].l; int r = tree[rt].r; if (L <= l && r <= R) return tree[rt].sum % mod; int mid = (l + r) >> 1; pushdown(rt); ll ans = 0; if (L <= mid) ans = ans + query(rt << 1, L, R) % mod; if (R > mid) ans = ans + query(rt << 1 | 1, L, R) % mod; return ans % mod; } //------------------------------------------------以上为线段树部分 void dfs1(int x) { size[x] = 1, d[x] = d[fa[x]] + 1; for (int v, i = head[x]; i; i = Next[i]) if ((v = ver[i]) != fa[x]) { fa[v] = x, dfs1(v), size[x] += size[v]; if (size[son[x]] < size[v]) son[x] = v; } } void dfs2(int x, int tp) { top[x] = tp, id[x] = ++cnt, rk[cnt] = x; if (son[x]) dfs2(son[x], tp); for (int v, i = head[x]; i; i = Next[i]) if ((v = ver[i]) != fa[x] && v != son[x]) dfs2(v, v); } inline ll sum(int x, int y) { ll ret = 0; while (top[x] != top[y]) { if (d[top[x]] < d[top[y]]) swap(x, y); (ret += query(1, id[top[x]], id[x])) %= mod; x = fa[top[x]]; } if (id[x] > id[y]) swap(x, y); return (ret + query(1, id[x], id[y])) % mod; } inline void change(int x, int y, int c) { while (top[x] != top[y]) { if (d[top[x]] < d[top[y]]) swap(x, y); update(1, id[top[x]], id[x], c); x = fa[top[x]]; } if (id[x] > id[y]) swap(x, y); update(1, id[x], id[y], c); } int main () { //freopen("input.in", "r", stdin); //freopen("test.out", "w", stdout); read(n); read(m); read(r); read(mod); for (int i = 1; i <= n; i++) read(a[i]); for (int i = 1; i < n; i++) { int u, v; read(u); read(v); add(u, v); add(v, u); } dfs1(r); dfs2(r, r); build(1, 1, n); for (int i = 1; i <= m; i++) { int op; read(op); if (op == 1) { ll x, y, z; read(x); read(y); read(z); change(x, y, z); } else if (op == 2) { ll x, y; read(x); read(y); ll ans = sum(x, y); out(ans); puts(""); } else if (op == 3) { ll x, y; read(x); read(y); update(1, id[x], id[x] + size[x] - 1, y); } else { ll x; read(x); ll ans = query(1, id[x], id[x] + size[x] - 1); out(ans); puts(""); } } return 0 ; }
-
参考博客