树链剖分——杨子曰算法
先搞一道题,给一个n个结点的树(n<=30000),每个结点上有一个权值,然后是m个询问(m<=200000),对于每个询问k,a,b,如果k为1,则输出从a到b路径上的和,如果k为2,则将结点a的权值变为b
今天我们来曰树链剖分,听听这名字,就透露着一种大佬的气息。看到题目,你要想到一个东东——LCA(不知道它是神马鬼的童鞋戳我),倍增的去求和是很好的想法,但还要update的事实告诉你现实的残酷(TLE)。
你还要想到另一个东东——线段树(不知道它是神马鬼的童鞋请停止阅读这篇文章,然后戳我),BBUUTT,线段树做的是区间,要建立在一维(二维也可以)的链上,这里是一棵树,肿么办——树链剖分就是把一棵树咔嚓咔嚓地砍几下,变成几条链,然后在线段树搞一搞,完事!我们开始:
对于一棵树,我们要把它砍成几条链,砍成的链越少越好,链越少复杂度越低,这一点, 地球人都知道,至于怎样把一棵树砍成尽量少尽量的链呢?先look at the 图:
这是一棵树(废话):
这是被我们剖分(砍)完以后的树:
总共是三条链,到底怎么分呢?
首先你要知道一个东西叫——重儿子(有人说:就是所有儿子中体重最重的),就是这个点的所有儿子中拥有子树最大的
比如上面那个图:2就是1的重儿子,5就是2的重儿子
这样一来我们就可以把链的一端从更结点出发,向重儿子延续这条链,在用重儿子继续DFS下去,至于其他的儿子则作为新的一条链的链顶,用同样的方法DFS下去,我们就把一棵树剖分好了,但代码打起来,没有你想象的简单,你需要写两个DFS,之前也说了,我们一会要把这棵树投射到线段树上,所以我们要把树上的每个节点重新编号——这是用来做线段树的编号,注意,编号时,同一条链上的点编号需要是连续的(待会你就知道为什么了)。这比较简单只要在剖分找重儿子的时候顺便编个号就好了
现在,我们开始讲两个DFS:
- 第一个DFS
你需要求:
- f[i] (结点i的父亲) 作用:用于DFS灌水以及后面求和
- dep[i] (结点i的深度) 作用:后面求和
- sz[i] (以i为根节点的这棵子树的大小(节点数)) 作用求son[i]
- son[i] (结点i的重儿子) 作用:第二个的DFS进行树链剖分
我觉得第一个DFS特别简单,直接代码走起:
void dfs1(int v,int fa,int d){
sz[v]=1;
f[v]=fa;
dep[v]=d;
for (int i=head[v];i!=-1;i=edge[i].next){
int u=edge[i].to;
if (u==fa) continue;
dfs1(u,v,d+1);
sz[v]+=sz[u];
if (!son[v] || sz[u]>sz[son[v]]) son[v]=u;
}
}
- 第二个DFS
你需要求:- top[i](结点i所在这条链的链顶)
- id[i] (结点i在线段树的编号)
写代码时我们把重儿子单独写出来,看看代码应该就懂了,走起:
void dfs2(int v,int fa,int tp){
top[v]=tp;
id[v]=++sum;
if (son[v]) dfs2(son[v],v,tp);
for (int i=head[v];i!=-1;i=edge[i].next){
int u=edge[i].to;
if (u==fa || u==son[v]) continue;
dfs2(u,v,u);
}
}
哦,历经千辛万苦,我们终于把这道题的准备工作做完了,这棵树已经被我们砍成几条链了,接下来就是求和,相信有很多大佬在前面的时候就已经有疑问了,如果要询问的两个点在同一条链上,那结点编号就是连续的,放在线段树上搞一搞,很简单,BUT更多的时候询问的两个点不在同一条链上呀,那肿么办?杨子曰:别急,马上就明白了
假设我们现在要求两个不在同一条链上的两个节点x,y路径上的和
1.只要x,y不在同一条链上,就重复2,否则进行3:
2.将x,y中所在链的链顶较低的点,跳到这条链的链顶的父亲,并算出跳跃部分的和 (这段和是可以算的,因为它在一条链上,链上的点在线段树上是一个区间)
3.现在x,y在一条链上了,这就意味着这段路径的和可以算了,完事!
我们搞个实例模拟一下:
现在x,y不在同一条链上,比较下绿色链和蓝色链的链顶,发现x所在的绿色链链顶低(深度大),我们把x跳到链顶的父亲结点1,算出S1部分的和,再比较下红色链和蓝色链的链顶,发现y所在的蓝色链链顶低,我们把y跳到链顶的父亲结点2,算出S2部分的和,现在x,y(分别在1,2)都在同一条红色链上了,算出S3,欧了。
这里线段树的操作asksum,update就不多说了自己戳
提一下:如果你想路径update也可以,你就需要像算和那样跳着在线段树上更新
代码走起:
long long Qsum(int x,int y){
long long ans=0;
while(top[x]!=top[y]){
int f1=top[x],f2=top[y];
if(dep[f1]<dep[f2]){
ans+=asksum(1,n,id[f2],id[y],1);
y=f[f2];
}
else{
ans+=asksum(1,n,id[f1],id[x],1);
x=f[f1];
}
}
if (dep[x]>dep[y]) swap(x,y);
ans+=asksum(1,n,id[x],id[y],1);
return ans;
}
如果你还知道神马是LCA的话(不知道,戳我),你就可以用树链剖分求LCA(←它比倍增更快更优)
OK,完事
完整模板C++代码(HYSBZ - 1036)
#include<bits/stdc++.h>
using namespace std;
struct Edge{
int next,to;
}edge[120005];
int sum=0,nedge=0,n;
int head[60005],sz[60005],f[60005],dep[60005],son[60005],top[60005],id[60005],w[60005];
int maxx[240005];
long long summ[240005];
void addedge(int a,int b){
edge[nedge].to=b;
edge[nedge].next=head[a];
head[a]=nedge++;
}
void dfs1(int v,int fa,int d){
sz[v]=1;
f[v]=fa;
dep[v]=d;
for (int i=head[v];i!=-1;i=edge[i].next){
int u=edge[i].to;
if (u==fa) continue;
dfs1(u,v,d+1);
sz[v]+=sz[u];
if (!son[v] || sz[u]>sz[son[v]]) son[v]=u;
}
}
void dfs2(int v,int fa,int tp){
top[v]=tp;
id[v]=++sum;
if (son[v]) dfs2(son[v],v,tp);
for (int i=head[v];i!=-1;i=edge[i].next){
int u=edge[i].to;
if (u==fa || u==son[v]) continue;
dfs2(u,v,u);
}
}
void pushup(int nod){
maxx[nod]=max(maxx[nod*2],maxx[nod*2+1]);
summ[nod]=summ[nod*2]+summ[nod*2+1];
}
void update(int l,int r,int k,int v,int nod){
if (l==r){
maxx[nod]=v;
summ[nod]=v;
return;
}
int mid=(l+r)/2;
if (k<=mid) update(l,mid,k,v,nod*2);
else update(mid+1,r,k,v,nod*2+1);
pushup(nod);
}
int askmax(int l,int r,int ll,int rr,int nod){
if (l==ll && r==rr) return maxx[nod];
int mid=(l+r)/2;
if (rr<=mid) return askmax(l,mid,ll,rr,nod*2);
else if (ll>mid) return askmax(mid+1,r,ll,rr,nod*2+1);
else return max(askmax(l,mid,ll,mid,nod*2),askmax(mid+1,r,mid+1,rr,nod*2+1));
}
int Qmax(int x,int y){
int ans=-200000000;
while(top[x]!=top[y]){
int f1=top[x],f2=top[y];
if (dep[f1]<dep[f2]) {
ans=max(ans,askmax(1,n,id[f2],id[y],1));
y=f[f2];
}
else{
ans=max(ans,askmax(1,n,id[f1],id[x],1));
x=f[f1];
}
}
if (dep[x]>dep[y]) swap(x,y);
ans=max(ans,askmax(1,n,id[x],id[y],1));
return ans;
}
long long asksum(int l,int r,int ll,int rr,int nod){
if (l==ll && r==rr) return summ[nod];
int mid=(l+r)/2;
if (rr<=mid) return asksum(l,mid,ll,rr,nod*2);
else if (ll>mid) return asksum(mid+1,r,ll,rr,nod*2+1);
else return asksum(l,mid,ll,mid,nod*2)+asksum(mid+1,r,mid+1,rr,nod*2+1);
}
long long Qsum(int x,int y){
long long ans=0;
while(top[x]!=top[y]){
int f1=top[x],f2=top[y];
if(dep[f1]<dep[f2]){
ans+=asksum(1,n,id[f2],id[y],1);
y=f[f2];
}
else{
ans+=asksum(1,n,id[f1],id[x],1);
x=f[f1];
}
}
if (dep[x]>dep[y]) swap(x,y);
ans+=asksum(1,n,id[x],id[y],1);
return ans;
}
int main(){
memset(head,-1,sizeof(head));
scanf("%d",&n);
for (int i=1;i<n;i++){
int a,b;
scanf("%d%d",&a,&b);
addedge(a,b);
addedge(b,a);
}
dfs1(1,0,1);
dfs2(1,0,1);
for (int i=1;i<=n;i++){
int x;
scanf("%d",&x);
update(1,n,id[i],x,1);
}
int m;
scanf("%d",&m);
while(m--){
char s[10]; scanf("%s",s);
if (strcmp(s,"QMAX")==0){
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",Qmax(x,y));
}
if (strcmp(s,"QSUM")==0){
int x,y;
scanf("%d%d",&x,&y);
printf("%lld\n",Qsum(x,y));
}
if (strcmp(s,"CHANGE")==0){
int x,y;
scanf("%d%d",&x,&y);
update(1,n,id[x],y,1);
}
}
return 0;
}
于 XJZX 507机房
未经作者允许,严禁转载:https://blog.csdn.net/HenryYang2018/article/details/81000472