理解
树链剖分和DFS序很像,都是将树形转化为线形然后用线段树维护。树链剖分和DFS序的唯一不同就在与走的方向不同:DFS序按照顺序走子结点,树链剖分先走重儿子。其他基本一致。
有几个问题:求一个结点的子树的所有值。DFS序是连续的,所以这里体现不出来树链剖分的优势。
求两个结点路径上的和。这里可以用LCA解决,用dis数组记录结点到根的距离,就可以dis [ x ] + dis [ y ] - 2 * dis [ lca ] ,就可以得到两点的距离。但是如果加上修改操作,每次都用再计算一遍dis数组,复杂度就爆炸。这里就用到了树链剖分!
DFS序在线段树中其实是一块一块的,在树上连续的链,在线段树上才有意义。或者说我们可以把树分割成一条条的链,在线段树上一个小块维护树的一条链。但是DFS序的链一般都很短,或者说链的长度是按照DFS序决定的,我们想让链尽可能的长,这样长链在线段树上对应的标号是连续的。
这样做会导致链的个数最少(logN个),链数少就会省时间(试想有N条链,完全没有优化)。因为在查询两点路径的走法,如果两点不在一条重链上, ans加上x点到x所在链顶端;如果两点在一条重链上,加上此时两个点的区间和。
理解两个DFS & 图好评:https://www.cnblogs.com/ivanovcraft/p/9019090.html
代码有详细注释:https://www.cnblogs.com/chinhhh/p/7965433.html
#include <bits/stdc++.h>
using namespace std;
int mod;
const int N=2e5+10;
struct Edge{
int next,to;
}edge[2*N];
int head[N],tot;
void addEdge(int from,int to)
{
edge[tot].to = to; edge[tot].next = head[from];
head[from] = tot++;
}
int cnt,f[N],d[N],siz[N],son[N],id[N],rk[N],top[N];
void dfs1(int u,int fa,int dep)
{
f[u] = fa;
d[u] = dep;
siz[u] = 1;
for(int i=head[u];i!=-1;i=edge[i].next){
int v = edge[i].to;
if(v==fa) continue;
dfs1(v,u,dep+1);
siz[u] += siz[v];
if(siz[v]>siz[son[u]]) son[u] = v;
}
}
void dfs2(int u,int t)
{
top[u] = t;
id[u] = ++cnt;
rk[cnt] = u;
if(!son[u]) return;
dfs2(son[u],t);
for(int i=head[u];i!=-1;i=edge[i].next){
int v = edge[i].to;
if(v!=son[u]&&v!=f[u]) dfs2(v,v);
}
}
struct{
int l,r;
ll w,lazy;
}tr[N*4];
int a[N];
void build(int k,int l,int r)
{
tr[k].l=l; tr[k].r=r; tr[k].lazy=0;
if(l==r){
tr[k].w = a[rk[l]]%mod;
return;
}
int mid = (l+r)/2;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
tr[k].w = (tr[k*2].w + tr[k*2+1].w)%mod;
}
void pushdown(int k)
{
if(tr[k].lazy){
tr[k*2].lazy += tr[k].lazy;
tr[k*2+1].lazy += tr[k].lazy;
tr[k*2].lazy %= mod;
tr[k*2].lazy %= mod;
tr[k*2].w += (tr[k*2].r-tr[k*2].l+1)*tr[k].lazy;
tr[k*2+1].w += (tr[k*2+1].r-tr[k*2+1].l+1)*tr[k].lazy;
tr[k*2].w %= mod;
tr[k*2+1].w %= mod;
tr[k].lazy = 0;
}
}
void update(int k,int l,int r,int v)
{
v %= mod;
if(l>tr[k].r||r<tr[k].l) return;
if(l<=tr[k].l&&tr[k].r<=r){
tr[k].lazy += v;
tr[k].lazy %= mod;
tr[k].w +=(tr[k].r-tr[k].l+1)*v%mod;
return;
}
pushdown(k);
update(k*2,l,r,v);
update(k*2+1,l,r,v);
tr[k].w = (tr[k*2].w + tr[k*2+1].w)%mod;
}
int query(int k,int l,int r)
{
if(l>tr[k].r||r<tr[k].l) return 0;
pushdown(k);
if(l<=tr[k].l&&tr[k].r<=r) return tr[k].w%mod;
return (query(k*2,l,r)+query(k*2+1,l,r))%mod;
}
void uprange(int x,int y,int v)
{
v %= mod;
while (top[x]!=top[y])
{
if(d[top[x]]<d[top[y]]) swap(x, y);
update(1,id[top[x]],id[x],v);
x = f[top[x]];
}
if(d[x]>d[y]) swap(x, y);
update(1,id[x],id[y],v);
}
int qrange(int x,int y)
{
int ans = 0;
while(top[x]!=top[y])
{
if(d[top[x]]<d[top[y]]) swap(x, y);
ans += query(1,id[top[x]],id[x]);
ans %= mod;
x = f[top[x]];
}
if (d[x]>d[y]) swap(x, y);
ans += query(1,id[x],id[y]);
return ans%mod;
}
void upson(int k,int v)
{
update(1,id[k],id[k]+siz[k]-1,v);
}
int qson(int k)
{
int res = query(1,id[k],id[k]+siz[k]-1);
return res%mod;
}
int main()
{
tot = 0;
memset(head,-1,sizeof(head));
int n,m,root,p;
scanf("%d%d%d%d",&n,&m,&root,&mod);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
for(int i=1;i<n;i++){
int x,y; scanf("%d%d",&x,&y);
addEdge(x,y);
addEdge(y,x);
}
dfs1(root,0,1);
dfs2(root,root);
build(1,1,n);
int op,x,y,z;
while(m--){
scanf("%d",&op);
if(op==1){
scanf("%d%d%d",&x,&y,&z);
uprange(x,y,z);
}
else if(op==2){
scanf("%d%d",&x,&y);
printf("%d\n",qrange(x,y));
}
else if(op==3){
scanf("%d%d",&x,&z);
upson(x,z);
}
else if(op==4){
scanf("%d",&x);
printf("%d\n",qson(x));
}
}
return 0;
}