树链剖分,顾名思义,就是将一棵树上的节点按照一个特殊的方式重新编号,这样我们就可以利用一些数据结构去优化加速一些树上的操作;
现在要介绍的是重链剖分;
首先明确一些概念:
重儿子:父亲节点的所有儿子中子树结点数目最多(size最大)的结点;
轻儿子:父亲节点中除了重儿子以外的儿子;
重边:父亲结点和重儿子连成的边;
轻边:父亲节点和轻儿子连成的边;
重链:由多条重边连接而成的路径;
轻链:由多条轻边连接而成的路径;
有了这些概念,我们就可以愉快地剖分了;
具体操作就是先用两个 dfs作出以下变量:
名称 | 解释 |
fa[u] | 保存结点u的父亲节点 |
dep[u] | 保存结点u的深度值 |
size[u] | 保存以u为根的子树节点个数 |
son[u] | 保存重儿子 |
rk[u] | 保存当前dfs标号在树中所对应的节点 |
top[u] | 保存当前节点所在链的顶端节点 |
dfn[u] | 保存树中每个节点剖分以后的新编号(DFS的执行顺序) |
然后在写一棵线段树(某数据结构),将树上节点以 dfs序映射到线段上,然后就可以优化树上操作了,这就是树链剖分;
附上代码:
#include<cstdio> #include<cstring> #include<iostream> #include<algorithm> using namespace std; const int N = 1e5+10; int n,m,r,mod; int val[N]; int dfn[N],top[N],rk[N]; int dep[N],fa[N],size[N],son[N]; struct node{ int l,r,ls,rs,sum,lazy; }a[N<<4]; struct edge{ int next,to; }e[N<<4]; int head[N],cnt; void add(int u,int v){ e[++cnt]=(edge){head[u],v}; head[u]=cnt; } void dfs1(int x){ size[x]=1; dep[x]=dep[fa[x]]+1; for(int v,i=head[x];i;i=e[i].next){ v=e[i].to; if(dep[v]) continue; fa[v]=x; dfs1(v); size[x]+=size[v]; if(size[v]>size[son[x]]) son[x]=v; } } void dfs2(int x,int t){ top[x]=t; dfn[x]=++cnt; rk[cnt]=x; if(son[x]) dfs2(son[x],t); for(int i=head[x],v;i;i=e[i].next){ v=e[i].to; if(v==fa[x]) continue; if(son[x]!=v) dfs2(v,v); } } void pushup(int o){a[o].sum=(a[a[o].ls].sum+a[a[o].rs].sum)%mod;} void build(int o,int l,int r){ if(l==r){ a[o].sum=val[rk[l]]; a[o].l=a[o].r=l; return ; } int mid=(l+r)>>1; a[o].ls=++cnt,a[o].rs=++cnt; build(a[o].ls,l,mid); build(a[o].rs,mid+1,r); a[o].l=a[a[o].ls].l; a[o].r=a[a[o].rs].r; pushup(o); } void pushdown(int o){ if(a[o].lazy){ int ls=a[o].ls,rs=a[o].rs; a[ls].lazy=(a[ls].lazy+a[o].lazy)%mod; a[rs].lazy=(a[rs].lazy+a[o].lazy)%mod; a[ls].sum=(a[ls].sum+(a[ls].r-a[ls].l+1)*a[o].lazy)%mod; a[rs].sum=(a[rs].sum+(a[rs].r-a[rs].l+1)*a[o].lazy)%mod; a[o].lazy=0; } } void updata(int o,int x,int y,int d){ if(a[o].l>=x&&a[o].r<=y){ a[o].lazy+=d; a[o].sum=(a[o].sum+(a[o].r-a[o].l+1)*d)%mod; return ; } pushdown(o); int mid=(a[o].l+a[o].r)>>1; if(x<=mid) updata(a[o].ls,x,y,d); if(y>mid) updata(a[o].rs,x,y,d); pushup(o); } int query(int o,int x,int y){ if(a[o].l>=x&&a[o].r<=y) return a[o].sum; pushdown(o); int mid=(a[o].l+a[o].r)>>1; int rel=0; if(x<=mid) rel=(rel+query(a[o].ls,x,y))%mod; if(y>mid) rel=(rel+query(a[o].rs,x,y))%mod; return rel; } int getsum(int x,int y){ int rel=0; while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); rel=(rel+query(1,dfn[top[x]],dfn[x]))%mod; x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); rel=(rel+query(1,dfn[x],dfn[y]))%mod; return rel; } int change(int x,int y,int d){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); updata(1,dfn[top[x]],dfn[x],d); x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); updata(1,dfn[x],dfn[y],d); } int main() { scanf("%d%d%d%d",&n,&m,&r,&mod); for(int i=1;i<=n;++i) scanf("%d",&val[i]); for(int i=1,x,y;i<n;++i){ scanf("%d%d",&x,&y); add(x,y); add(y,x); } cnt=0,dfs1(r),dfs2(r,r); build(1,1,n); for(int i=1,op,x,y,z;i<=m;++i){ scanf("%d",&op); if(op==1){ scanf("%d%d%d",&x,&y,&z); change(x,y,z); } if(op==2){ scanf("%d%d",&x,&y); printf("%d\n",getsum(x,y)); } if(op==3){ scanf("%d%d",&x,&z); updata(1,dfn[x],dfn[x]+size[x]-1,z); } if(op==4){ scanf("%d",&x); printf("%d\n",query(1,dfn[x],dfn[x]+size[x]-1)); } } return 0; }