本文除代码外均参考或引自oi wiki上相关内容(树链剖分 - OI Wiki)
树链剖分的思想、能解决的问题
给定一棵静态(形状固定的)树,要求进行几种操作:
1、修改单个节点/树上两点之间的路径/一个节点的子树上的所有点的值。
2、查询单个节点/树上两点之间的路径/一个节点的子树上节点的值的和/极值/其他具有较强的合并性的信息。
如果树的形态是一条链,那么我们只需要维护一个线段树,通过修改或查询线段树的值即可得到相关答案;但由于需要处理的是一棵树,我们考虑将这棵树剖分成多个链,并用线段树修改或查询答案,这就是树链剖分的思想。
剖分及相关修改/查询操作
规定对于树上任意节点 x x x:
1、 x x x的重儿子 s o n ( x ) son(x) son(x):所有子节点中子树最大(子树节点最多)的节点;
2、重边:连接 x x x与 s o n ( x ) son(x) son(x)的边;
3、轻边:连接 x x x与除 s o n ( x ) son(x) son(x)外其它子节点的边;
4、重链(重路径):相连通的重边连成的链。
我们知道, d f s dfs dfs序可以保证某个节点及其子树内所有节点的 d f s dfs dfs序在且占满一个连续的 d f s dfs dfs序值的区间,那么,我们预处理出所有节点的重儿子后,再次dfs整棵树,此时搜到某个节点后规定之后先搜这个节点的重儿子,这样就可以保证任意一条重链上所有的节点对应的 d f s dfs dfs序在且占满一个连续的 d f s dfs dfs序值的区间,而且深度越深的节点对应 d f s dfs dfs序显然越大。这样我们便可以通过线段树修改相关信息。
void dfs1(int x,int father)//预处理部分信息
{
fa[x]=father,dep[x]=dep[fa[x]]+1,siz[x]=1;
for(ri k=fst[x];k>0;k=nxt[k])
if(v[k]!=fa[x])
{
dfs1(v[k],x);
siz[x]+=siz[v[k]];
if(siz[v[k]]>valt[x])
valt[x]=siz[v[k]],son[x]=v[k];
}
}
void dfs2(int x,int anc)//剖分轻重链
{
top[x]=anc;//dfn[x]=++cnt; rk[cnt]=x;
if(son[x]!=0) dfs2(son[x],anc);//此处if是为了防止dfs2爆栈
for(ri k=fst[x];k>0;k=nxt[k])
if((v[k]!=fa[x])&&(v[k]!=son[x])) dfs2(v[k],v[k]);
}
对于文章开头叙述的相关操作,我们需要解决在于如何维护或查询两点间路径相关信息,这就引出如何在被轻重链剖分后的树上求LCA的问题。
考虑我们是如何用 倍增法求解 LCA 的。首先我们 将两个节点提到同一高度,然后将两个节点一起向上跳 。对于树链剖分也可以使用这样的思想。在向上跳的过程中,如果当前节点在重链上,则令其跳到重链顶端,如果当前节点不在重链上,经过轻边向上跳一个节点。如此直到两节点跳到同一条链上。
int LCA(int x,int y)
{
while(top[x]!=top[y])//top[x]:x所在链的顶端的点的序号
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];//每次循环将所在链链顶深度较大的点往上跳,直到两个点跳到同一条链为止
}
return (dep[x]<dep[y])?x:y;//两个点跳到同一条链时,深度较小的点即为LCA(x,y)
}
对于修改或查询操作,我们只需要沿途利用线段树对每一段跳过的重链或部分重链对应的 d f s dfs dfs序区间更新/查询区间信息即可。
时间复杂度及证明
以上算法的时间复杂度为 O ( q l o g 2 n ) O(qlog^2n) O(qlog2n),其中 q q q为修改或查询操作的此时, n n n为树的节点个数,下证:
可以证明,如果 u u u是 v v v的父亲且 v v v不是 u u u的重儿子,有 2 s i z ( v ) ≤ s i z ( u ) 2siz(v)\leq siz(u) 2siz(v)≤siz(u)(反之, v v v一定是 u u u的重儿子,矛盾)。由此可知,在求LCA的过程中,我们每一次将某个节点 u u u O ( 1 ) O(1) O(1)提到链顶且通过轻边向上跳到另一条重链上的某个节点 x x x时,一定有 2 s i z ( u ) ≤ 2 s i z ( t o p u ) ≤ s i z ( x ) 2siz(u)\leq2siz(top_u)\leq siz(x) 2siz(u)≤2siz(topu)≤siz(x),即:如果我们顺着重链或轻边向上跳,新的点的子树节点个数相比跳之前前的节点至少乘以2子树节点个数最大为 n n n,最多跳 l o g 2 n log_2n log2n次,故树链剖分求任意两点 L C A LCA LCA的复杂度为 O ( l o g n ) O(logn) O(logn)。再乘上跳的过程中线段树对每一段跳过的重链或部分重链修改或查询的时间复杂度 O ( l o g n ) O(logn) O(logn)和操作数 q q q,最后的时间复杂度为 O ( q l o g 2 n ) O(qlog^2n) O(qlog2n),证毕。
Code
P3379 【模板】最近公共祖先(LCA)
#include<cstdio>
#include<iostream>
#define ri register int
using namespace std;
const int MAXN=5e5+20;
int N,Q,R,M,u[MAXN<<1],v[MAXN<<1],fst[MAXN<<1],nxt[MAXN<<1],xi,yi;
int dep[MAXN],fa[MAXN],siz[MAXN],valt[MAXN],son[MAXN],top[MAXN];
void dfs1(int x,int father)//预处理部分信息
{
fa[x]=father,dep[x]=dep[fa[x]]+1,siz[x]=1;
for(ri k=fst[x];k>0;k=nxt[k])
if(v[k]!=fa[x])
{
dfs1(v[k],x);
siz[x]+=siz[v[k]];
if(siz[v[k]]>valt[x])
valt[x]=siz[v[k]],son[x]=v[k];
}
}
void dfs2(int x,int anc)//剖分轻重链
{
top[x]=anc;//dfn[x]=++cnt; rk[cnt]=x;
if(son[x]!=0) dfs2(son[x],anc);//此处if是为了防止dfs2爆栈
for(ri k=fst[x];k>0;k=nxt[k])
if((v[k]!=fa[x])&&(v[k]!=son[x])) dfs2(v[k],v[k]);
}
int LCA(int x,int y)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];//每次循环将所在链链顶深度较大的点往上跳,直到两个点跳到同一条链为止
}
return (dep[x]<dep[y])?x:y;//两个点跳到同一条链时,深度较小的点即为LCA(x,y)
}
int main()
{
scanf("%d%d%d",&N,&Q,&R);
M=(N-1)<<1;
for(ri i=1;i<=M;i+=2)
{
scanf("%d%d",&u[i],&v[i]);
nxt[i]=fst[u[i]],fst[u[i]]=i;
u[i+1]=v[i],v[i+1]=u[i];
nxt[i+1]=fst[u[i+1]],fst[u[i+1]]=i+1;
}
dep[0]=-1;
dfs1(R,0);
dfs2(R,R);
for(ri op=1;op<=Q;++op)
{
scanf("%d%d",&xi,&yi);
cout<<LCA(xi,yi)<<'\n';
}
return 0;
}
P3384 【模板】轻重链剖分
#include<cstdio>
#include<iostream>
#define ri register int
#define ll long long
using namespace std;
const int MAXN=1e5+20;
int N,Q,R,M,u[MAXN<<1],v[MAXN<<1],fst[MAXN<<1],nxt[MAXN<<1],opt,xi,yi;
int dep[MAXN],fa[MAXN],siz[MAXN],valt[MAXN],son[MAXN],top[MAXN],dfn[MAXN],rk[MAXN],cnt;
int l[MAXN<<2],r[MAXN<<2],len[MAXN<<2];
ll MOD,a[MAXN],zi,sum[MAXN<<2],add[MAXN<<2];
void dfs1(int x,int father)
{
fa[x]=father,dep[x]=dep[fa[x]]+1,siz[x]=1;
for(ri k=fst[x];k>0;k=nxt[k])
if(v[k]!=fa[x])
{
dfs1(v[k],x);
siz[x]+=siz[v[k]];
if(siz[v[k]]>valt[x])
valt[x]=siz[v[k]],son[x]=v[k];
}
}
void dfs2(int x,int anc)
{
top[x]=anc,dfn[x]=++cnt; rk[dfn[x]]=x;
if(son[x]!=0) dfs2(son[x],anc);
for(ri k=fst[x];k>0;k=nxt[k])
if((v[k]!=fa[x])&&(v[k]!=son[x])) dfs2(v[k],v[k]);
}
void pushup(int p)
{
sum[p]=(sum[p <<1]+sum[p <<1|1])%MOD;
}
void pushdown(int p)
{
sum[p <<1]=(sum[p <<1]+add[p]*len[p <<1]%MOD)%MOD;
add[p <<1]=(add[p <<1]+add[p])%MOD;
sum[p <<1|1]=(sum[p <<1|1]+add[p]*len[p <<1|1]%MOD)%MOD;
add[p <<1|1]=(add[p <<1|1]+add[p])%MOD;
add[p]=0;
}
void build(int p,int lft,int rit)
{
l[p]=lft,r[p]=rit,len[p]=rit-lft+1;
if(l[p]==r[p])
{
sum[p]=a[rk[l[p]]];//注意;建树时子节点的值需利用rk数组查询dfn为l[p]时对应节点的编号
return;
}
ri mid=(lft+rit)>>1;
build(p <<1,lft,mid); build(p <<1|1,mid+1,rit);
pushup(p);
}
void update(int p,int lft,int rit,ll k)
{
if(lft<=l[p]&&r[p]<=rit)
{
sum[p]=(sum[p]+len[p]*k%MOD)%MOD,add[p]=(add[p]+k)%MOD;
return;
}
pushdown(p);
if(lft<=r[p <<1]) update(p <<1,lft,rit,k);
if(l[p <<1|1]<=rit) update(p <<1|1,lft,rit,k);
pushup(p);
}
ll query(int p,int lft,int rit)
{
if(lft<=l[p]&&r[p]<=rit) return sum[p];
pushdown(p);
ll ans=0;
if(lft<=r[p <<1]) ans=query(p <<1,lft,rit);
if(l[p <<1|1]<=rit) ans+=query(p <<1|1,lft,rit);
return ans%MOD;
}
void LCAu(int x,int y,ll k)//修改u->v间最短路径点权
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,dfn[top[x]],dfn[x],k);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,dfn[x],dfn[y],k);
return;
}
ll LCAq(int x,int y)//查询u->v间最短路径点权和
{
ll ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans=(ans+query(1,dfn[top[x]],dfn[x]))%MOD;
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=query(1,dfn[x],dfn[y]);
return ans%MOD;
}
int main()
{
scanf("%d%d%d%lld",&N,&Q,&R,&MOD);
for(ri i=1;i<=N;++i) scanf("%lld",&a[i]);
M=(N-1)<<1;
for(ri i=1;i<=M;i+=2)
{
scanf("%d%d",&u[i],&v[i]);
nxt[i]=fst[u[i]],fst[u[i]]=i;
u[i+1]=v[i],v[i+1]=u[i];
nxt[i+1]=fst[u[i+1]],fst[u[i+1]]=i+1;
}
dep[0]=-1;
dfs1(R,0);
dfs2(R,R);
build(1,1,N);
for(ri op=1;op<=Q;++op)
{
scanf("%d",&opt);
if(opt==1)
{
scanf("%d%d%lld",&xi,&yi,&zi);
LCAu(xi,yi,zi);
}
if(opt==2)
{
scanf("%d%d",&xi,&yi);
cout<<LCAq(xi,yi)<<'\n';
}
if(opt==3)
{
scanf("%d%lld",&xi,&zi);
update(1,dfn[xi],dfn[xi]+siz[xi]-1,zi);
}
if(opt==4)
{
scanf("%d",&xi);
cout<<query(1,dfn[xi],dfn[xi]+siz[xi]-1)<<'\n';
}
}
return 0;
}