题意
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作 1 1 1: 格式: 1 1 1 x x x y y y z z z 表示将树从 x x x到 y y y结点最短路径上所有节点的值都加上 z z z
操作 2 2 2: 格式: 2 2 2 x x x y y y 表示求树从 x x x到 y y y结点最短路径上所有节点的值之和
操作 3 3 3: 格式: 3 3 3 x x x z z z 表示将以 x x x为根节点的子树内所有节点值都加上 z z z
操作 4 4 4: 格式: 4 4 4 x x x 表示求以 x x x为根节点的子树内所有节点值之和
思路
树链剖分,就是把一棵树剖成若干链,然后在树上的操作就转变成在序列里的操作,如何剖是关键。
定义如下数组:
f
a
t
h
e
r
father
father,记录节点的me父亲
d
e
p
dep
dep,记录节点的深度
s
i
z
e
size
size,记录节点的子树大小
s
o
n
son
son,记录节点的重儿子,就是该节点
s
i
z
e
size
size最大的儿子
r
e
v
rev
rev,记录这个
d
f
s
dfs
dfs序对应的原来的节点
t
o
p
top
top,记录当前点处在的重链的顶端
s
e
g
seg
seg,记录当前点的
d
f
s
dfs
dfs序
2
2
2次
d
f
s
dfs
dfs即可求出,计算
s
e
g
seg
seg时,同一条重链上的点需要按顺序排在连续的一段区间。
这样子把树拆成若干条链,用线段树维护其。
树链剖分的两个性质:
1 , 1, 1,如果 ( u , v ) (u,v) (u,v)是一条轻边,那么 s i z e ( v ) ≤ s i z e ( u ) / 2 size(v)\leq size(u)/2 size(v)≤size(u)/2;
2 , 2, 2,从根结点到任意结点的路所经过的轻重链的个数必定都小于 l o g n logn logn;
因此树链剖分的总复杂度是 O ( l o g 2 n ) O(log^2n) O(log2n)的。
对于本题的
4
4
4种操作:
1
、
1、
1、若两点不在同一条重链上,将
t
o
p
top
top较深的点往
f
a
t
h
e
r
[
t
o
p
]
father[top]
father[top]上跳,并更新跳过的区间。直到跳到一条重链上,这样就可以直接更新。
2
、
2、
2、同
1
1
1,不过将修改操作换成查询操作。
3
、
3、
3、同一个子树的
d
f
s
dfs
dfs序一定是连续的,所以查询
x
∼
x
+
s
i
z
e
x
−
1
x\sim x+size_x-1
x∼x+sizex−1间的答案。
4
、
4、
4、同
3
3
3,查询操作变成修改操作。
其实操作 1 1 1的跳就相当于求 L C A LCA LCA,只不过这里直接跳完一个链。线段树就基本操作吧。
代码
#include <cctype>
#include <cstdio>
#include <algorithm>
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
char buf[1 << 21], *p1 = buf, *p2 = buf;
struct segmentTree {
int ls, rs, l, r, sum, lazy;
}t[400001];
int father[100001], dep[100001], size[100001], son[100001], rev[100001], top[100001], seg[100001];
int head[100001], ver[200001], next[200001];
int a[100001];
int n, m, r, mod, tot, cnt;
inline long long read() {
long long res = 0, f = 0;
char ch = getchar();
while (!isdigit(ch)) {
if (ch == '-') f = 1;
ch = getchar();
}
while (isdigit(ch)) res = res * 10 + ch - 48, ch = getchar();
return f ? -res : res;
}
void add(int u, int v) {
ver[++tot] = v;
next[tot] = head[u];
head[u] = tot;
}
void dfs1(int u, int fa) {
father[u] = fa;
size[u] = 1;
dep[u] = dep[fa] + 1;
for (int i = head[u]; i; i = next[i]) {
if (ver[i] == fa) continue;
dfs1(ver[i], u);
size[u] += size[ver[i]];
if (size[ver[i]] > size[son[u]]) son[u] = ver[i];
}
}
void dfs2(int u, int t) {
top[u] = t;
seg[u] = ++cnt;
rev[cnt] = u;
if (!son[u]) return;
dfs2(son[u], t);
for (int i = head[u]; i; i = next[i]) {
if (top[ver[i]]) continue;
dfs2(ver[i], ver[i]);
}
}
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r;
if (l == r) {
t[p].sum = a[rev[l]] % mod;
return;
}
int mid = l + r >> 1;
build(t[p].ls = ++cnt, l, mid);
build(t[p].rs = ++cnt, mid + 1, r);
t[p].sum = (t[t[p].ls].sum + t[t[p].rs].sum) % mod;
}
inline void spread(int p) {
if (!t[p].lazy) return;
t[t[p].ls].lazy = (t[t[p].ls].lazy + t[p].lazy) % mod;
t[t[p].rs].lazy = (t[t[p].rs].lazy + t[p].lazy) % mod;
t[t[p].ls].sum = (t[t[p].ls].sum + (t[t[p].ls].r - t[t[p].ls].l + 1) * t[p].lazy) % mod;
t[t[p].rs].sum = (t[t[p].rs].sum + (t[t[p].rs].r - t[t[p].rs].l + 1) * t[p].lazy) % mod;
t[p].lazy = 0;
}
int query(int p, int l, int r) {
if (l <= t[p].l && t[p].r <= r)
return t[p].sum;
spread(p);
int mid = t[p].l + t[p].r >> 1, res = 0;
if (l <= mid) res = (res + query(t[p].ls, l, r)) % mod;
if (r > mid) res = (res + query(t[p].rs, l, r)) % mod;
return res;
}
inline int ask(int x, int y) {
int fx = top[x], fy = top[y], res = 0;
while (fx != fy) {
if (dep[fx] < dep[fy]) std::swap(x, y), std::swap(fx, fy);
res = (res + query(1, seg[fx], seg[x])) % mod;
x = father[fx];
fx = top[x];
}
if (dep[x] > dep[y]) std::swap(x, y);
res = (res + query(1, seg[x], seg[y])) % mod;
return res;
}
void update(int p, int l, int r, int val) {
if (l <= t[p].l && t[p].r <= r) {
t[p].sum = (t[p].sum + (t[p].r - t[p].l + 1) * val) % mod;
t[p].lazy = (t[p].lazy + val) % mod;
return;
}
spread(p);
int mid = t[p].l + t[p].r >> 1;
if (l <= mid) update(t[p].ls, l, r, val);
if (r > mid) update(t[p].rs, l, r, val);
t[p].sum = (t[t[p].ls].sum + t[t[p].rs].sum) % mod;
}
void modify(int x, int y, int val) {
int fx = top[x], fy = top[y];
while (fx != fy) {
if (dep[fx] < dep[fy]) std::swap(x, y), std::swap(fx, fy);
update(1, seg[fx], seg[x], val);
x = father[fx];
fx = top[x];
}
if (dep[x] > dep[y]) std::swap(x, y);
update(1, seg[x], seg[y], val);
}
signed main() {
n = read(), m = read(), r = read(), mod = read();
for (register int i = 1; i <= n; i++)
a[i] = read();
for (register int i = 1; i < n; i++) {
int x = read(), y = read();
add(x, y), add(y, x);
}
dfs1(r, 0);
dfs2(r, r);
build(cnt = 1, 1, n);
for (int op, x, y, z; m; m--) {
op = read();
if (op == 1) {
x = read(), y = read(), z = read() % mod;
modify(x, y, z);
} else if (op == 2) {
x = read(), y = read();
printf("%d\n", ask(x, y));
} else if (op == 3) {
x = read(), z = read() % mod;
update(1, seg[x], seg[x] + size[x] - 1, z);
} else {
x = read();
printf("%d\n", query(1, seg[x], seg[x] + size[x] - 1));
}
}
}