树链剖分:
首先给出几个定义:
重儿子:对于每个非叶子节点来说,它的儿子中,重儿子作为根节点的子树大小最大。
轻儿子:对于每个非叶子节点来说,它的儿子中除了重儿子以外的儿子全是轻儿子。
显然叶子节点没有重儿子也没有轻儿子。
重链:从上至下,以轻儿子为起点,其余全是重儿子的链。
接下来需要进行两边 d f s dfs dfs ,求出树链剖分所需要的元素:
d
f
s
1
dfs1
dfs1:
f
a
[
v
]
fa[v]
fa[v]:节点
v
v
v 的父节点。若
v
v
v 为根则为
0
0
0。
w
s
o
n
[
v
]
wson[v]
wson[v]:节点
v
v
v 的重儿子。若
v
v
v 为叶子节点则为
0
0
0。
d
e
p
t
h
[
v
]
depth[v]
depth[v]:节点
v
v
v 在树中的深度。
s
i
z
[
v
]
siz[v]
siz[v]:节点
v
v
v 作为根的子树大小。
void dfs1(const int &v)
{
depth[v] = depth[fa[v]] + 1, siz[v] = 1;
for (int i = head[v]; i; i = e[i].next)
{
int u = e[i].to;
if (siz[u])
continue;
fa[u] = v;
dfs1(u);
siz[v] += siz[u];
if (siz[wson[v]] < siz[u])//选择最大的子树的根为重儿子
wson[v] = u;
}
}
d
f
s
2
dfs2
dfs2:
p
s
ps
ps:这次
d
f
s
dfs
dfs 对于节点
v
v
v ,对于其儿子的遍历要优先遍历重儿子,再遍历轻儿子,这样做保证重链上的点编号
i
d
id
id 为连续的。
i
d
[
v
]
id[v]
id[v]:
v
v
v 节点的
d
f
s
dfs
dfs 序
t
o
p
[
v
]
top[v]
top[v]:
v
v
v 节点所在的重链上的最上面的轻儿子节点。
s
[
i
d
[
v
]
]
=
v
s[id[v]] = v
s[id[v]]=v:建立节点与其
i
d
id
id 的联系。
void dfs2(const int &v, const int &topv)
{
id[v] = ++cnt;
s[cnt] = w[v], top[v] = topv;
if(wson[v] == 0)
return;
dfs2(wson[v], topv);//先遍历重儿子
for (int i = head[v]; i; i = e[i].next)
{
int u = e[i].to;
if(id[u])
continue;
dfs2(u, u);
}
}
对于此题一共有四种操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
首先对于操作3,4,由于之前进行 d f s 2 dfs2 dfs2 时记录下了 d f s dfs dfs 序,那么,对于一个节点 v v v ,以及其子树, i d id id 值均在 [ i d [ v ] , i d [ v ] + s i z [ v ] − 1 ] [id[v],id[v]+siz[v]-1] [id[v],id[v]+siz[v]−1] 这个区间,那么操作3,4即为对这个区间进行区间修改以及区间求和即可。
操作1,2:最短路径很明显就是求 L C A LCA LCA ,这里我们模仿倍增法求 L C A LCA LCA 的思路:
步骤1. 对于节点
x
,
y
x,y
x,y,选择其
d
e
p
t
h
[
t
o
p
[
]
]
depth[top[]]
depth[top[]](即该节点所在重链的最上面的轻儿子的深度)较大的那个节点
v
v
v ,操作该轻儿子的
i
d
id
id 到节点
v
v
v 的
i
d
id
id 这段区间(因为对于两条重链,
L
C
A
LCA
LCA 不可能在深度更深的那条重链上,所以,深度更深的那条重链的顶部到节点
v
v
v 上的点均位于
x
,
y
x,y
x,y 的树上最短路径) ,然后使
v
=
f
a
[
t
o
p
[
v
]
]
v = fa[top[v]]
v=fa[top[v]] 。
重复进行这一个步骤,直到两者的
d
e
p
t
h
[
t
o
p
[
]
]
depth[top[]]
depth[top[]] 相等。
步骤2. 此时, L C A LCA LCA 显然是在更新完毕之后的 x , y x,y x,y 之间,而 x , y x,y x,y 位于同一条重链上,所以只需要操作深度小的 i d id id 到深度大的 i d id id 这段区间即可。
if (op == 1 || op == 2)
{
read(y);
if (op == 1)
read(z), z %= p;
int ans = 0;
while (top[x] != top[y])
{
if (depth[top[x]] < depth[top[y]])
swap(x, y);
if (op == 1)
update(1, id[top[x]], id[x], z);
else
(ans += query(1, id[top[x]], id[x])) %= p;
x = fa[top[x]];
}
if (depth[x] > depth[y])
swap(x, y);
if (op == 1)
update(1, id[x], id[y], z);
else
(ans += query(1, id[x], id[y])) %= p, write(ans), putchar('\n');
}
到这里操作1,2整个就进行完毕。
这里区间的操作用线段树完成即可。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define lowbit(x) (x & (-x))
#define ffor(i,d,u) for(int i=(d);i<=(u);++i)
#define _ffor(i,u,d) for(int i=(u);i>=(d);--i)
#define INF 0x3f3f3f3f
#define mst(arrary, count, kind, num) memset(arrary, num, sizeof(kind) * (count))
const ll LLINF = 0x3f3f3f3f3f3f3f3f;
template <typename T>
void read(T& x)
{
x=0;
char c;T t=1;
while(((c=getchar())<'0'||c>'9')&&c!='-');
if(c=='-'){t=-1;c=getchar();}
do(x*=10)+=(c-'0');while((c=getchar())>='0'&&c<='9');
x*=t;
}
template <typename T>
void write(T x)
{
int len=0;char c[21];
if(x<0)putchar('-'),x*=(-1);
do{++len;c[len]=(x%10)+'0';}while(x/=10);
_ffor(i,len,1)putchar(c[i]);
}
const int mod = 1e9 + 7;
const int NUM = 1e5 + 5;
const int MAXN = 4e5 + 5;
int n, m, r, p;
int head[NUM] = {}, ednum = 0;
int fa[NUM], wson[NUM] = {}, depth[NUM], siz[NUM] = {};
int cnt = 0, id[NUM] = {}, top[NUM];
int s[NUM], w[NUM]; //点权
struct edge
{
int to, next;
} e[NUM * 2];
inline void addedge(const int &v, const int &u)
{
e[++ednum] = edge{u, head[v]}, head[v] = ednum;
}
void dfs1(const int &v)
{
depth[v] = depth[fa[v]] + 1, siz[v] = 1;
for (int i = head[v]; i; i = e[i].next)
{
int u = e[i].to;
if (siz[u])
continue;
fa[u] = v;
dfs1(u);
siz[v] += siz[u];
if (siz[wson[v]] < siz[u])
wson[v] = u;
}
}
void dfs2(const int &v, const int &topv)
{
id[v] = ++cnt;
s[cnt] = w[v], top[v] = topv;
if(wson[v] == 0)
return;
dfs2(wson[v], topv);
for (int i = head[v]; i; i = e[i].next)
{
int u = e[i].to;
if(id[u])
continue;
dfs2(u, u);
}
}
struct linetree
{
int l, r, mid;
int sum, lazy, len;
} t[MAXN];
void build(int o, int l, int r)
{
t[o].l = l, t[o].r = r, t[o].lazy = 0, t[o].len = (r - l + 1) % p;
if (l == r)
{
t[o].sum = s[l] % p;
return;
}
t[o].mid = (l + r) >> 1;
build(o << 1, l, t[o].mid), build(o << 1 | 1, t[o].mid + 1, r);
t[o].sum = (t[o << 1].sum + t[o << 1 | 1].sum) % p;
}
inline void pushdown(int o)
{
if(t[o].lazy)
{
(t[o << 1].lazy += t[o].lazy) %= p, (t[o << 1 | 1].lazy += t[o].lazy) %= p;
t[o].lazy = 0;
}
}
inline int getsum(const int &o)
{
return (t[o].lazy * t[o].len % p + t[o].sum) % p;
}
inline void pushup(int o)
{
t[o].sum = (getsum(o << 1) + getsum(o << 1 | 1)) % p;
}
void update(int o, int l, int r, int z)
{
if (t[o].l == l && t[o].r == r)
{
(t[o].lazy += z) %= p;
return;
}
pushdown(o);
if (t[o].mid < l)
update(o << 1 | 1, l, r, z);
else
{
if(t[o].mid >= r)
update(o << 1, l, r, z);
else
update(o << 1, l, t[o].mid, z), update(o << 1 | 1, t[o].mid + 1, r, z);
}
pushup(o);
}
int query(int o, int l, int r)
{
if (t[o].l == l && t[o].r == r)
return getsum(o);
pushdown(o);
int ans = 0;
if (t[o].mid < l)
ans = query(o << 1 | 1, l, r);
else
{
if(t[o].mid >= r)
ans = query(o << 1, l, r);
else
ans = (query(o << 1, l, t[o].mid) + query(o << 1 | 1, t[o].mid + 1, r)) % p;
}
pushup(o);
return ans;
}
inline void ac()
{
int u, v, op, x, y, z;
read(n), read(m), read(r), read(p);
ffor(i, 1, n) read(w[i]);
ffor(i, 2, n)
{
read(u), read(v);
addedge(u, v), addedge(v, u);
}
fa[r] = 0, dfs1(r), dfs2(r, r), build(1, 1, n);
while (m--)
{
read(op), read(x);
if (op == 1 || op == 2)
{
read(y);
if (op == 1)
read(z), z %= p;
int ans = 0;
while (top[x] != top[y])
{
if(depth[top[x]] < depth[top[y]])
swap(x, y);
if (op == 1)
update(1, id[top[x]], id[x], z);
else
(ans += query(1, id[top[x]], id[x])) %= p;
x = fa[top[x]];
}
if (depth[x] > depth[y])
swap(x, y);
if (op == 1)
update(1, id[x], id[y], z);
else
(ans += query(1, id[x], id[y])) %= p, write(ans), putchar('\n');
}
else if (op == 3)
{
read(z), z %= p;
update(1, id[x], id[x] + siz[x] - 1, z);
}
else
write(query(1, id[x], id[x] + siz[x] - 1)), putchar('\n');
}
}
int main()
{
ac();
return 0;
}