树链剖分就是将树分割成多条链,然后利用数据结构(线段树、树状数组等)来维护这些链。
首先就是一些必须知道的概念:
重节点:子树结点数目最多的结点;
轻节点:父亲节点中除了重结点以外的结点;
重边:父亲结点和重结点连成的边;
轻边:父亲节点和轻节点连成的边;
重链:由多条重边连接而成的路径;
轻链:由多条轻边连接而成的路径;
我们通过维护这些树链,就可以完成对树的区间处理。
我们先定义以下数组:
son[x]:点x的重儿子节点; siz[x]:点x的子树大小;
fa[x]:点x的父亲节点; dep[x]:点x的深度;
top[x]:点x所在链的最上面的节点; tid[x]:点x的dfs序;
rank[x]:x在树中相应的值(也可以是位置)//这是线段树套树剖时用的
首先,我们可以用一个dfs求出每个点的siz,dep,fa,son。
void dfsI(int x,int father){
fa[x]=father;son[x]=-1;
dep[x]=dep[fa[x]]+1;siz[x]=1;
for(int i=head[x];i!=-1;i=nxt[i]){
int go=to[i];
if(go==fa[x]) continue;
dfsI(go,x);siz[x]+=siz[go];
if(son[x]==-1||siz[go]>siz[son[x]]) son[x]=go;
}
}
这是个非常简单的dfs,相信大家都会。
接下去,我们要求出每一条链,我们以参数t记录点x的链的链顶。如果点x为重节点,则dfs(x,t),继续以t为链顶。如果x为轻节点,则dfs(x,x),即轻链的top为自己,这样就可以用top维护出重链和轻链。
void dfsII(int x,int t){
tid[x]=++cnt;
rank[cnt]=val[x]%q;
top[x]=t;
if(son[x]==-1) return;
dfsII(son[x],t);
for(int i=head[x];i!=-1;i=nxt[i]){
int go=to[i];
if(go==fa[x]||go==son[x]) continue;
dfsII(go,go);
}
}
树剖以三个dfs闻名,现在已经两个了,还有一个。用来求从u到v的树链上的所有点权值之和(也可以更新从u到v的树链上的点的值)。我们运用一个类似LCA的思想,如果top[u]!=top[v],就可以将其中的某一个往上跳。(只能跳一个,否则可能会两个点擦肩而过,导致程序WA),当top[u]=top[v]时,直接用线段树求解之间的区间和即可。
//里面的update和query就是线段树的操作,到时候在下面的程序里会显现
int queryLink(int x,int y){ //树链求和
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
ans+=query(1,1,n,tid[top[y]],tid[y]);
ans%=q;
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=query(1,1,n,tid[x],tid[y]);
ans%=q;
return ans;
}
void updateLink(int x,int y,int ad){ //树链修改
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
update(1,1,n,tid[top[y]],tid[y],ad);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,1,n,tid[x],tid[y],ad);
}
这是洛谷上一道树链剖分的模板题,我们可以用线段树维护树剖解决。
我们就用线段树维护树剖即可。
Code:
#include<bits/stdc++.h>
#define MAXN 100005
using namespace std;
int read(){
char c;int x;while(c=getchar(),c<'0'||c>'9');x=c-'0';
while(c=getchar(),c>='0'&&c<='9') x=x*10+c-'0';return x;
}
int n,m,r,q,cnt,rec;
int son[MAXN],rank[MAXN],tid[MAXN],nxt[MAXN<<1],head[MAXN],top[MAXN],fa[MAXN];
int to[MAXN<<1],val[MAXN],sum[MAXN<<2],add[MAXN<<2],dep[MAXN],siz[MAXN];
void addedge(int x,int y){
to[rec]=y;nxt[rec]=head[x];head[x]=rec;rec++;
to[rec]=x;nxt[rec]=head[y];head[y]=rec;rec++;
}
void up(int node){sum[node]=(sum[node<<1]+sum[node<<1|1])%q;}
void down(int node,int ls,int rs){ //下推标记
if(add[node]){
add[node<<1]+=add[node];
add[node<<1|1]+=add[node];
sum[node<<1]+=add[node]*ls;
sum[node<<1|1]+=add[node]*rs;
sum[node<<1]%=q;sum[node<<1|1]%=q;
add[node]=0;
}
}
void dfsI(int x,int father){ //第一个DFS
fa[x]=father;son[x]=-1;
dep[x]=dep[fa[x]]+1;siz[x]=1;
for(int i=head[x];i!=-1;i=nxt[i]){
int go=to[i];
if(go==fa[x]) continue;
dfsI(go,x);siz[x]+=siz[go];
if(son[x]==-1||siz[go]>siz[son[x]]) son[x]=go;
}
}
void dfsII(int x,int t){ //第二个DFS
tid[x]=++cnt;
rank[cnt]=val[x]%q;
top[x]=t;
if(son[x]==-1) return;
dfsII(son[x],t);
for(int i=head[x];i!=-1;i=nxt[i]){
int go=to[i];
if(go==fa[x]||go==son[x]) continue;
dfsII(go,go);
}
}
void build(int node,int l,int r){ //建树
if(l==r){
sum[node]=rank[l];
return;
}
int mid=(l+r)>>1;
build(node<<1,l,mid);
build(node<<1|1,mid+1,r);
up(node);
}
void update(int node,int l,int r,int L,int R,int ad){ //线段树区间修改
if(L<=l&&r<=R){
sum[node]+=ad*(r-l+1);sum[node]%=q;
add[node]+=ad;return;
}
int mid=(l+r)>>1;
down(node,mid-l+1,r-mid);
if(L<=mid) update(node<<1,l,mid,L,R,ad);
if(R>mid) update(node<<1|1,mid+1,r,L,R,ad);
up(node);
}
int query(int node,int l,int r,int L,int R){ //区间求和
if(L<=l&&r<=R){
return sum[node];
}
int mid=(l+r)>>1,ans=0;
down(node,mid-l+1,r-mid);
if(L<=mid) ans+=query(node<<1,l,mid,L,R);
if(R>mid) ans+=query(node<<1|1,mid+1,r,L,R);
ans%=q;
return ans;
}
int qTree(int x){ //这是操作4
return query(1,1,n,tid[x],tid[x]+siz[x]-1);
}
void upTree(int x,int ad){ //操作3
update(1,1,n,tid[x],tid[x]+siz[x]-1,ad);
}
int queryLink(int x,int y){ //DFS求链上权值和
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
ans+=query(1,1,n,tid[top[y]],tid[y]);
ans%=q;
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=query(1,1,n,tid[x],tid[y]);
ans%=q;
return ans;
}
void updateLink(int x,int y,int ad){ //链上修改
while(top[x]!=top[y]){
if(dep[top[x]]>dep[top[y]]) swap(x,y);
update(1,1,n,tid[top[y]],tid[y],ad);
y=fa[top[y]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,1,n,tid[x],tid[y],ad);
}
int main()
{
memset(head,-1,sizeof(head));
n=read();m=read();r=read();q=read();
for(int i=1;i<=n;i++) val[i]=read();
for(int i=1;i<=n-1;i++){
int x=read(),y=read();
addedge(x,y);
}
dfsI(r,0);
dfsII(r,r);
build(1,1,n);
for(int i=1;i<=m;i++){
int x=read();
if(x==1){
int u=read(),v=read(),ad=read();
updateLink(u,v,ad);
}
if(x==2){
int u=read(),v=read();
printf("%d\n",queryLink(u,v));
}
if(x==3){
int u=read(),ad=read();
upTree(u,ad);
}
if(x==4){
int u=read();
printf("%d\n",qTree(u));
}
}
return 0;
}