前言
树剖是连接图论和数据结构的一道桥梁,绝大多数数据结构都是建立在线性表中,
与二维的图天然隔绝,而树剖则提供了一个很好的转换方法
可以将二维的图转化为一维的表,以重链选择较多,长链也有
利用dfn序将树存储与线段树中,这个性质很重要:dfn序连续的节点在树链中也连续
树剖的代码相对好理解一些,不过代码比较长,其大致实现思路是:
- 进行第一次dfs遍历,求出以下数组:pre[], deep[], num[], son[], 即父节点、深度数组、子树大小和重子节点,此次dfs为无差别搜索
- 然后进行第二次dfs遍历,记录以下数组:dfn[], top[], rk[],即dfn序,链首节点和链内排名,此次dfs为优先重子搜索
至此,便可以利用dfn序和线段树结合吗,共同维护此树
一、 例题 p3384
题目链接 洛谷 p3384
二、 思路及代码
1. 思路
很经典的一道树剖与线段树的结合,套板子即可
2. 代码
代码如下 :
#include <cstring>
#include <iostream>
// #define int long long 会RE ???
using namespace std;
const int maxn = 2e5 + 5;
int mod;
int n, m, r;
int pre[maxn], deep[maxn], num[maxn], son[maxn]; // 第一次dfs的变量
int dfn[maxn], top[maxn], rk[maxn], id; // 第二次dfs的时候用到的变量
int val[maxn], dfnval[maxn]; // 建树时用dfn序下的val
struct e {
int to, next;
} edge[maxn << 1];
int head[maxn], cnt;
void init() {
cnt = 1, id = 0;
memset(head, -1, sizeof(head));
for (int i = 0; i < maxn; i++) edge[i].next = -1;
}
struct segtree {
int l, r;
int val, add;
} t[maxn << 2];
void addedge(int u, int v) {
edge[cnt] = e{v, head[u]};
head[u] = cnt++;
}
void predfs(int u, int fa, int d) {
deep[u] = d, num[u] = 1, pre[u] = fa;
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if (v == fa) continue;
predfs(v, u, d + 1);
num[u] += num[v]; // 子树的节点数
if (num[son[u]] < num[v]) son[u] = v; // 重链子树的根节点
}
}
void dfs(int u, int t) {
dfn[u] = ++id, dfnval[id] = val[u], top[u] = t, rk[id] = u;
if (!son[u]) return; // 按重链顺序进行dfs
dfs(son[u], t);
for (int i = head[u]; i != -1; i = edge[i].next) {
int v = edge[i].to;
if (v == pre[u] || v == son[u]) continue;
dfs(v, v); // 更新链首值top为v
}
}
void build(int root, int l, int r) {
t[root].l = l, t[root].r = r;
if (l == r) {
t[root].val = dfnval[r];
return;
}
int mid = (l + r) >> 1;
build(root * 2, l, mid);
build(root * 2 + 1, mid + 1, r);
t[root].val = (t[root * 2].val + t[root * 2 + 1].val) % mod;
}
void spread(int p) {
int l = t[p].l, r = t[p].r;
if (t[p].add) {
t[p * 2].add = (t[p * 2].add + t[p].add) % mod;
t[p * 2 + 1].add = (t[p * 2 + 1].add + t[p].add) % mod;
t[p * 2].val =
(t[p * 2].val + (t[p * 2].r - t[p * 2].l + 1) * t[p].add) % mod;
t[p * 2 + 1].val =
(t[p * 2 + 1].val + (t[p * 2 + 1].r - t[p * 2 + 1].l + 1) * t[p].add) %
mod;
t[p].add = 0;
}
}
void update(int root, int l, int r, int x) {
if (l <= t[root].l && t[root].r <= r) {
t[root].add = (t[root].add + x) % mod;
t[root].val = (t[root].val + (t[root].r - t[root].l + 1) * x) % mod;
return;
}
spread(root);
int mid = (t[root].r + t[root].l) >> 1;
if (l <= mid) update(root * 2, l, r, x);
if (mid < r) update(root * 2 + 1, l, r, x);
t[root].val = (t[root * 2].val + t[root * 2 + 1].val) % mod;
}
int query(int root, int l, int r) {
if (l <= t[root].l && t[root].r <= r) return t[root].val;
spread(root);
int res = 0;
int mid = (t[root].l + t[root].r) >> 1;
if (l <= mid) res = (res + query(root * 2, l, r)) % mod;
if (mid < r) res = (res + query(root * 2 + 1, l, r)) % mod;
return res;
}
void updatepath(int u, int v, int x) {
while (top[u] != top[v]) { // 更新不同链上的节点
if (deep[top[u]] < deep[top[v]]) swap(u, v);
update(1, dfn[top[u]], dfn[u], x);
u = pre[top[u]]; // 优先更新链首深度大的节点,更快迭代
}
if (deep[u] < deep[v]) swap(u, v); // 更新相同链上的节点
update(1, dfn[v], dfn[u], x);
}
int querypath(int u, int v) {
int res = 0;
while (top[u] != top[v]) { // 查询不同链上的节点
if (deep[top[u]] < deep[top[v]]) swap(u, v);
res = (res + query(1, dfn[top[u]], dfn[u])) % mod;
u = pre[top[u]]; // 优先查询链首深度大的节点,更快迭代
}
if (deep[u] < deep[v]) swap(u, v); // 查询相同链上的节点
res = (res + query(1, dfn[v], dfn[u])) % mod;
return res;
}
signed main() {
// freopen("in.txt", "r", stdin);
// freopen("out.txt", "w", stdout);
init();
scanf("%d%d%d%d", &n, &m, &r, &mod);
for (int i = 1; i <= n; i++) scanf("%d", &val[i]);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
addedge(u, v), addedge(v, u);
}
predfs(r, 0, 1);
dfs(r, r);
build(1, 1, n);
while (m--) {
int t, u, v, w;
scanf("%d", &t);
if (t == 1) {
scanf("%d%d%d", &u, &v, &w);
updatepath(u, v, w);
} else if (t == 2) {
scanf("%d%d", &u, &v);
printf("%d\n", querypath(u, v));
} else if (t == 3) {
scanf("%d%d", &u, &w);
update(1, dfn[u], dfn[u] + num[u] - 1, w);
} else {
scanf("%d", &u);
printf("%d\n", query(1, dfn[u], dfn[u] + num[u] - 1));
}
}
return 0;
}