终于学到了树剖!!!
前置知识:LCA,树形DP,DFS,邻接表,线段树。
树链剖分
线段树的特点:区间修改,区间查询,线性;
树上差分特点:单点修改,树形区间查询;
现在,如果我们想进行树形区间修改和查询,是否存在一种算法能够做到呢?
搜狗百科对于树剖的定义:树链剖分,计算机术语,指一种对树进行划分的算法,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、SBT、SPLAY、线段树等)来维护每一条链。
轻重树链剖分(启发式剖分)
首先明确一些概念:
重儿子:父亲节点的所有儿子中已那个儿子为根的子树节点数目最多(size最大)的节点;
轻儿子:对于每一个非叶子结点,它的儿子中的非重儿子节点
叶子结点既没有重儿子也没有轻儿子
重边:父亲节点和重儿子所连的边
轻边:剩下的边即为轻边
重链:相邻重边连起来的 连接一条重儿子的链叫重链
对于叶子节点,若其为轻儿子,则有一条以自己为起点的长度为1的链
每一条重链以轻儿子为起点
dfs1()
这个dfs要处理几件事情:
标记每个点的深度dep[N]
标记每个点的父亲fa[N]
标记每个非叶子结点的子树大小siz[N]
标记每个非叶子结点的重儿子编号son[N]
void dfs1(int u,int Fa){
dep[u]=dep[Fa]+1;
fa[u]=Fa;
siz[u]=1;
int maxson=-1;
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==Fa) continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>maxson) son[u]=v,maxson=siz[v];
}
}
dfs2()
这个dfs2也要处理几件事情
标记每个点的新编号id[N]
赋值每个点的初始值到新编号上wt[N]
处理每个点所在链的顶端
处理每条链
顺序:先处理重儿子在处理轻儿子
void dfs2(int u,int topf){
id[u]=++cnt;
wt[cnt]=w[u];
top[u]=topf;
if(!son[u]) return;
dfs2(son[u],topf);
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
由于处理顺序是先处理重儿子再处理轻儿子,所以
每一条重链的新编号是连续的
因为是dfs,所以每一个子树的新编号也是连续的
现在回顾一下我们要处理的问题
处理任意两点间的路径和
处理一点及其子树的点权和
修改任意两点间路径上的点权
修改一点及其子树的点权
1、当我们处理任意两点间路径时:
对于点 u,v,u所在链的顶点深度 ≤ v所在链的顶点深度
ans+=u到u所在链顶端 这一段区间的点权和
把u调到u所在链顶端的点的父亲节点
重复执行以上步骤,直到两个点位于同一条链,此时再加上这两点间的区间和即可
int qRange(int u,int v){
int ans=0;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
res=0;
query(1,id[top[u]],id[u]);
ans+=res;
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
res=0;
query(1,id[u],id[v]);
ans+=res;
return ans%mod;
}
2、处理一点及其子树的点权和:
每个非叶子结点,子树的编号都是连续的,线段树区间查询即可
int qSon(int u){
res=0;
query(1,id[u],id[u]+siz[u]-1);
return res%mod;
}
当然,区间修改也一样
void updRange(int u,int v,int k){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
update(1,id[top[u]],id[u],k);
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
update(1,id[u],id[v],k);
}
void updSon(int u,int k){
update(1,id[u],id[u]+siz[u]-1,k);
}
建树
这里我们采用线段树维护链,按照题意建树即可。
代码
#include<bits/stdc++.h>
#define int long long
#define ls(r) r<<1
#define rs(r) r<<1|1
usingnamespace std;
constint N=1e6+10;
int n,q,tot,cnt,root,res,mod;
int h[N],w[N],wt[N];
int son[N],id[N],fa[N],dep[N],siz[N],top[N];
struct node{
int to,nxt,w;
}e[N<<1];
struct tree{
int l,r,ans,tag;
}t[N<<2];
void add(int u,int v){
e[++tot].to=v;
e[tot].nxt=h[u];
h[u]=tot;
}
void f(int p,int l,int r,int k){
t[p].tag+=k;
t[p].ans+=(r-l+1)*k;
t[p].ans%=mod;
}
void push_down(int p){
int mid=t[p].l+t[p].r>>1;
f(ls(p),t[p].l,mid,t[p].tag);
f(rs(p),mid+1,t[p].r,t[p].tag);
t[p].tag=0;
}
void push_up(int p){
t[p].ans=t[ls(p)].ans+t[rs(p)].ans;
}
void build(int l,int r,int p){
t[p].l=l,t[p].r=r;
if(l==r){
t[p].ans=wt[l];return;
}
int mid=l+r>>1;
build(l,mid,ls(p));
build(mid+1,r,rs(p));
push_up(p);
}
void query(int p,int l,int r){
if(l<=t[p].l&&t[p].r<=r){ res+=t[p].ans;return; }
if(t[p].tag) push_down(p);
int mid=t[p].l+t[p].r>>1;
if(l<=mid) query(ls(p),l,r);
if(r>mid) query(rs(p),l,r);
}
void update(int p,int l,int r,int k){
if(l<=t[p].l&&t[p].r<=r){
t[p].tag+=k;t[p].ans+=k*(t[p].r-t[p].l+1);return ;
}
push_down(p);
int mid=t[p].l+t[p].r>>1;
if(l<=mid) update(ls(p),l,r,k);
if(r>mid) update(rs(p),l,r,k);
push_up(p);
}
int qRange(int u,int v){
int ans=0;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
res=0;
query(1,id[top[u]],id[u]);
ans+=res;
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
res=0;
query(1,id[u],id[v]);
ans+=res;
return ans%mod;
}
void updRange(int u,int v,int k){
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
update(1,id[top[u]],id[u],k);
u=fa[top[u]];
}
if(dep[u]>dep[v]) swap(u,v);
update(1,id[u],id[v],k);
}
int qSon(int u){
res=0;
query(1,id[u],id[u]+siz[u]-1);
return res%mod;
}
void updSon(int u,int k){
update(1,id[u],id[u]+siz[u]-1,k);
}
void dfs1(int u,int Fa){
dep[u]=dep[Fa]+1;
fa[u]=Fa;
siz[u]=1;
int maxson=-1;
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==Fa) continue;
dfs1(v,u);
siz[u]+=siz[v];
if(siz[v]>maxson) son[u]=v,maxson=siz[v];
}
}
void dfs2(int u,int topf){
id[u]=++cnt;
wt[cnt]=w[u];
top[u]=topf;
if(!son[u]) return;
dfs2(son[u],topf);
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].to;
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
signed main(){
cin>>n>>q>>root>>mod;
for(int i=1;i<=n;++i) cin>>w[i];
for(int i=1,u,v;i<n;++i){
cin>>u>>v;add(u,v),add(v,u);
}
dfs1(root,0);
dfs2(root,root);
build(1,n,1);
while(q--){
int k,u,v,x;
cin>>k;
if(k==1){
cin>>u>>v>>x;
updRange(u,v,x);
}
elseif(k==2){
cin>>u>>v;
cout<<qRange(u,v)%mod<<endl;
}
elseif(k==3){
cin>>u>>v;
updSon(u,v);
}
else{
cin>>u;
cout<<qSon(u)%mod<<endl;
}
}
return0;
}