前置知识:基础求和线段树 + 树上操作 + dfs (废话)
一. 什么是树链剖分?
树链剖分,将一棵树分成多条链然后摆在一起用线段树维护区间和。
二. 基础概念有什么
- 重儿子:一个结点的儿子中子树节点数最大的那一个就是他的重儿子
- 轻儿子:一个结点中不是他的重儿子的儿子都是他的轻儿子
(叶子节点没有轻儿子 or 重儿子) - 重(zhong)边:连接两个重儿子的边
- 轻边:重边除外的边
- 链:树上任两点的简单路径
- 重链:链上的点都为重儿子的链
三 . 需要的属性
- 树映射到数组里某结点的新位置,新位置对应的树上结点的值
- 结点的深度
- 映射的数组的区间和(线段树维护)
- 结点的子树大小
- 结点所在重链的顶端
四 . 查询 or 修改 路径上的值
查询和修改中所作是相同的,就是找路径映射到路径上的点然后修改 or 查询。
使用轻重链的好处:在寻找时我们可以加速。
现在我们要找 x − y x - y x−y 路径映射的区间(们),两点在没到同一链之前,找两点中深的那一个点将它升到它的链的顶端,注意这里的链都是应该修改 o r or or查询的点之一( i d [ x ] − i d [ t o p [ x ] ] id[x]-id[top[x]] id[x]−id[top[x]] )。
当在同一链上时,区间为id[x] 到 id[y]。
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<stack>
#include<vector>
#include<queue>
#include<cstring>
#include<cmath>
#include<set>
#define int long long
using namespace std;
int mod;
int head[100005];
int dep[100005];//深度
int size[100005], son[100005], id[100005];//该点的子树大小,重儿子,dfs 序的序号
int top[100005], a[100005];// 链的顶端,某结点的初值
int sum[400005], tag[400005];//线段树
int f[100005];//一个子节点的父亲是谁
int w[100005];
int n, m, r, p;
struct Node{
int next, v;
};
int nw[100005];
Node edge[200005];
int tot = 0;
void add_edge(int u, int v){
tot++;
// cout << tot << endl;
edge[tot].next = head[u];
edge[tot].v = v;
head[u] = tot;
}
void dfs1(int u, int fa){
f[u] = fa;
dep[u] = dep[fa] + 1;//深度为自己父亲的深度 + 1
size[u] = 1;//子树大小只有自己时为 1
int maxs = 0;//当前最大子树的儿子的子树的结点个数
for(int i = head[u]; i; i = edge[i].next){
int v = edge[i].v;
if(v == fa){
continue;
}
dfs1(v, u);
size[u] += size[v];//自己子树的大小加上儿子的
if(size[v] > maxs){
son[u] = v;
maxs = size[v];
}
}
return;
}
void dfs2(int u, int topp){
tot++;
id[u] = tot;// dfs 序号
w[tot] = a[u];// dfs 序的第 tot 值为树上 u 结点的值
top[u] = topp;// top[u] 所在链的顶端
if(son[u] == 0){//没有重儿子那么它为叶子节点
return;
}
dfs2(son[u], topp);// u 的重儿子所在链就是 u 所在链
for(int i = head[u]; i; i = edge[i].next){
int v = edge[i].v;
if(v == f[u] || v == son[u]){//如果是轻儿子
continue;
}
dfs2(v, v);//如果是轻儿子链顶是自己
}
}
//以上是预处理
int lc(int p){
return (p * 2);
}
int rc(int p){
return (p * 2 + 1);
}
void push_up(int p){
sum[p] = (sum[lc(p)] + sum[rc(p)]) % mod;
}
void push_down(int p, int l, int r){
int mid = (l + r) / 2;
sum[lc(p)] = (sum[lc(p)] + (mid - l + 1) * tag[p] % mod) % mod;
sum[rc(p)] = (sum[rc(p)] + (r - mid) * tag[p] % mod) % mod;
tag[lc(p)] = (tag[lc(p)] + tag[p]) % mod;
tag[rc(p)] = (tag[p] + tag[rc(p)]) % mod;
tag[p] = 0;
return;
}
int query(int p, int l, int r, int ql, int qr){
if(l > qr || r < ql){
return 0;
}
if(l >= ql && r <= qr){
return sum[p];
}
push_down(p, l, r);
int mid = (l + r) / 2;
return (query(lc(p), l, mid, ql, qr) + query(rc(p), mid + 1, r, ql, qr)) % mod;
}
void update(int p, int l, int r, int ql, int qr, int t){
if(l > qr || r < ql){
return;
}
if(l >= ql && r <= qr){
sum[p] = (sum[p] + t * (r - l + 1) % mod) % mod;
tag[p] = (tag[p] + t) % mod;
return;
}
int mid = (l + r) / 2;
push_down(p, l, r);
update(lc(p), l, mid, ql, qr, t);
update(rc(p), mid + 1, r, ql, qr, t);
push_up(p);
return;
}
void build(int p, int l, int r){
if(l == r){
sum[p] = w[l];
return;
}
int mid = (l + r) / 2;
build(lc(p), l, mid);
build(rc(p), mid + 1, r);
push_up(p);
}
//以上是线段树
int qtrack(int x, int y){
int ans = 0;
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);//找深度更深的结点去跳到链顶
ans = (ans + query(1, 1, n, id[top[x]], id[x])) % mod;
x = f[top[x]];//跳到链顶之后再向上跳一格
}
if(dep[x] > dep[y]){
swap(x, y);
}
// cout << "end" << id[x] << " " << id[y] << endl;
ans = (ans + query(1, 1, n, id[x], id[y])) % mod;
return ans;
}
void track_up(int x, int y, int t){//修改与查询无二
while(top[x] != top[y]){
if(dep[top[x]] < dep[top[y]]) swap(x, y);
update(1, 1, n, id[top[x]], id[x], t);
x = f[top[x]];
}
if(dep[x] > dep[y]){
swap(x, y);
}
update(1, 1, n, id[x], id[y], t);
return;
}
signed main(){
cin >> n >> m >> r >> mod;
for(int i = 1; i <= n; i++){
cin >> a[i];
}
for(int i = 1; i <= n - 1; i++){
int u, v;
cin >> u >> v;
add_edge(u, v);
add_edge(v, u);
}
dfs1(r, 0);
tot = 0;
dfs2(r, r);
build(1, 1, n);
for(int i = 0; i < m; i++){
int op;
cin >> op;
if(op == 1){
int x, y, z;
cin >> x >> y >> z;
track_up(x, y, z);
} else if(op == 2){
int x, y;
cin >> x >> y;
cout << qtrack(x, y) << endl;
} else if(op == 3){
int x, z;
cin >> x >> z;
int l = id[x];
int r = id[x] + size[x] - 1;
update(1, 1, n, l, r, z);
} else if(op == 4){
int x;
cin >> x;
int l = id[x];
int r = id[x] + size[x] - 1;
cout << query(1, 1, n, l, r) << endl;
}
}
return 0;
}