树链剖分主要解决在树上某条路径上或某棵子树的sum与最值
入门树链剖分,最重要的概念是重儿子,用son[]记录。son[i]代表的是以i为根,节点最多子树根的编号。通过son,我们将树中的边分为两种,轻边和重边。重边是每一个点与它重儿子的连边。将链续的重边串起来,构成了一条条重链。对于每个点,它一定在某条重链上,特殊的,一个单独的点也可以是一条重链。对于一条重链上的点,他们的dfs序是连续的,对与一颗子树上的点,他们的dfs序也是连续的,于是我们将树上的点转化成一个区间,在区间用线段树上求解或修改。
树链剖分的核心便是如何将树剖分成若干链。
在此之前,先了解以下七个参数的意义。
fa[x] x的父亲节点编号
dep[x] x的深度
size[x] 以x为根子树的节点,用来求son
seg[x] 以son为基础的dfs2序,即其在线段树上的编号
rev[x] 用来将线段树上的编号转化为原编号。
son[x] 记录重儿子
top[x] 记录x所在重链dep最小的点
树链剖分需要的前置知识有DFS+LCA+线段树
我们在两次dfs中求出上述七个参数
inline void dfs1(int u,int f){ int i,v; size[u]=1; fa[u]=f; dep[u]=dep[f]+1; for(i=fir[u];v=to[i],i;i=nex[i]){ if(v!=f){ dfs1(v,u); size[u]+=size[v]; if(size[v]>size[son[u]])//更新重儿子 son[u]=v; } } }
dfs2:求出rev,seg,top。
inline void dfs2(int u,int f){ int i,v; if(son[u]){//优先遍历重儿子。 seg[son[u]]=++tim; rev[tim]=son[u]; top[son[u]]=top[u];//重儿子的top,就是u的top。 dfs2(son[u],u); } for(i=fir[u];v=to[i],i;i=nex[i]) if(!top[v]){//访问轻边 seg[v]=++tim; rev[tim]=v; top[v]=v;//轻边单独开了一条链,top是本身 dfs2(v,u); } }
两遍dfs就把整棵树划分为若干条链,剩下的就交给线段树解决了。
首先是建树
inline void build(int k,int l,int r){ if(l==r){ sum[k]=ma[k]=w[l];//w是每个点的全值 return; } int mid=l+r>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); ma[k]=max(ma[k<<1],ma[k<<1|1]); sum[k]=sum[k<<1]+sum[k<<1|1]; }
线段树的查询和修改类似。
inline void change(int k,int l,int r,int val,int pos){
//pos是当前要修改的的位置,val是改变后的值 if(l>pos||r<pos) return; if(l==r&&l==pos){ ma[k]=sum[k]=val; return; } int mid=l+r>>1; change(k<<1,l,mid,val,pos); change(k<<1|1,mid+1,r,val,pos); sum[k]=sum[k<<1]+sum[k<<1|1]; ma[k]=max(ma[k<<1],ma[k<<1|1]); } inline void query(int k,int x,int y,int l,int r){
//l~r为需要修改的区间 if(x>r||y<l) return; if(x>=l&&y<=r){ SUM+=sum[k]; MAX=max(ma[k],MAX); return; } int mid=x+y>>1; query(k<<1,x,mid,l,r); query(k<<1|1,mid+1,y,l,r); }
最后,我们只需要知道只需要知道哪些点对这条路径有贡献,统计他们的贡献即可。
inline void ask(int x,int y){
inline void ask(int x,int y){
int fx=top[x],fy=top[y]; while(fx!=fy){//如果他们不在同一重链上 if(dep[fx]<dep[fy]) swap(x,y),swap(fx,fy);//选取深度大的那一条, query(1,1,tim,seg[fx],seg[x]);//注意要将原编号转化为dfs序编号 x=fa[x],fx=top[x]; }
//如果他们在一条链上了,再统计x~y路径的贡献 if(dep[x]>dep[y]) swap(x,y);//保证x的编号小等于y query(1,1,tim,seg[x],seg[y]); }
下面附上一道模板题
https://www.lydsy.com/JudgeOnline/problem.php?id=1036
#include<cstdio> #include<iostream> #include<cstring> #define max(x,y) (x>y?x:y) #define N 100000 using namespace std; int n,m,tot,tim,SUM,MAX; int fir[N],to[N],nex[N]; int seg[N],rev[N],size[N],son[N],dep[N],top[N],fa[N]; int sum[N],ma[N],w[N]; inline void r(int &x){ bool sign=1; x=0; char ch=getchar(); while(ch<'0'||ch>'9') ch=getchar(); if(ch=='-') sign=0,ch=getchar(); while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+(ch^48),ch=getchar(); x=sign?x:-x; } inline void add(int x,int y){ to[++tot]=y,nex[tot]=fir[x],fir[x]=tot; to[++tot]=x,nex[tot]=fir[y],fir[y]=tot; } inline void dfs1(int u,int f){ int i,v; size[u]=1; fa[u]=f; dep[u]=dep[f]+1; for(i=fir[u];v=to[i],i;i=nex[i]){ if(v!=f){ dfs1(v,u); size[u]+=size[v]; if(size[v]>size[son[u]]) son[u]=v; } } } inline void dfs2(int u,int f){ int i,v; if(son[u]){ seg[son[u]]=++tim; rev[tim]=son[u]; top[son[u]]=top[u]; dfs2(son[u],u); } for(i=fir[u];v=to[i],i;i=nex[i]) if(!top[v]){ seg[v]=++tim; rev[tim]=v; top[v]=v; dfs2(v,u); } } inline void build(int k,int l,int r){ if(l==r){ sum[k]=ma[k]=w[l]; return; } int mid=l+r>>1; build(k<<1,l,mid); build(k<<1|1,mid+1,r); ma[k]=max(ma[k<<1],ma[k<<1|1]); sum[k]=sum[k<<1]+sum[k<<1|1]; } inline void change(int k,int l,int r,int val,int pos){ if(l>pos||r<pos) return; if(l==r&&l==pos){ ma[k]=sum[k]=val; return; } int mid=l+r>>1; change(k<<1,l,mid,val,pos); change(k<<1|1,mid+1,r,val,pos); sum[k]=sum[k<<1]+sum[k<<1|1]; ma[k]=max(ma[k<<1],ma[k<<1|1]); } inline void query(int k,int x,int y,int l,int r){ if(x>r||y<l) return; if(x>=l&&y<=r){ SUM+=sum[k]; MAX=max(ma[k],MAX); return; } int mid=x+y>>1; query(k<<1,x,mid,l,r); query(k<<1|1,mid+1,y,l,r); } inline void ask(int x,int y){ int fx=top[x],fy=top[y]; while(fx!=fy){ if(dep[fx]<dep[fy]) swap(x,y),swap(fx,fy); query(1,1,tim,seg[fx],seg[x]); x=fa[x],fx=top[x]; } if(dep[x]>dep[y]) swap(x,y); query(1,1,tim,seg[x],seg[y]); } int main() { int i,j,x,y; char op[10]; r(n); for(i=1;i<n;i++){ r(x),r(y); add(x,y); } for(i=1;i<=n;i++) r(w[i]); tim=seg[1]=top[1]=rev[1]=1; dfs1(1,0); dfs2(1,0); build(1,1,tim); r(m); for(i=1;i<=m;i++){ scanf("%s",op); r(x),r(y); SUM=0; MAX=-N; switch(op[1]){ case 'M':{ ask(x,y); printf("%d\n",MAX); break; } case 'S':{ ask(x,y); printf("%d\n",SUM); break; } case 'H':{ change(1,1,tim,y,seg[x]); break; } } } }
2019-09-04