介绍
树链剖分是一种比较高级的数据结构,这个数据结构就可以从名字上看出,实在树上进行操作的,那么这个数据结构到底是干什么的呢。
我们要先明确出这种数据结构所要解决的问题。
以下是洛谷中模板题的题目 传送门
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
问题来源
我们首先分析一下在暴力情况下每个操作的复杂度是怎么样的。
操作一显然是
O
(
n
)
O(n)
O(n)可以解决的。
操作二也是需要
O
(
n
)
O(n)
O(n)来解决的。
操作三是
O
(
n
)
O(n)
O(n)的
操作四四是
O
(
n
)
O(n)
O(n)的
然后看一眼数据范围,大仙这种问题解决不了,所以我们就需要一种高级的数据结构,就是树链剖分了。
然后呢,我们来看一下这个数据结构到底是个什么玩意,它是如何运行的。
名词解释
首先我们要明确一些名词的概念。
重儿子:子树结点数目最多的结点;
轻儿子:父亲节点中除了重结点以外的结点;
重边:父亲结点和重结点连成的边;
轻边:父亲节点和轻节点连成的边;
重链:由多条重边连接而成的路径;
轻链:由多条轻边连接而成的路径;
理解
为了理解这么一些概念,我们先画一颗树。
然后我们把它们的size求出来(size就是他的重量(重量就是里面几个儿子))
然后我们就会发现每个节点的重儿子。
然后我们将这个节点和他的重儿子连一条边,叫重边
然后我们就要重新编号了,就是要先编重儿子的号,再编轻儿子的号
然后就是对每一条重链建线段树了
代码实现
这是存东西用的数组
然后就是要将这棵树存下来
void addedge(int x,int y)
{
nxt[++cnt]=point[x];
to[cnt]=y;
point[x]=cnt;
return;
}
然后我们就对这颗树进行第一次dfs,这次dfs主要是求出他的size,来选出他的重儿子,来连重边
void dfs1(int x,int dep)
{
vis[x]=1;
deep[x]=dep;
size[x]=1;
int maxson=-1;
for(int i=point[x];i;i=nxt[i])
{
int p=to[i];
if(!vis[p])
{
fa[p]=x;
dfs1(p,dep+1);
size[x]+=size[p];
if(maxson<size[p])
{
maxson=size[p];
son[x]=p;
}
}
}
return;
}
然后我们就是要进行第二步,就是要再dfs一遍,来求出这个节点所在重链的最上端,和自己重新编号后的编号
void dfs2(int x,int chain)//chain是重链的祖先
{
newnum[x]=++tot;
newval[tot]=val[x];
top[x]=chain;
if(!son[x])return;
dfs2(son[x],chain);
for(int i=point[x];i;i=nxt[i])
{
int p=to[i];
if(!newnum[p])
{
dfs2(p,p);
}
}
return;
}
接下来就是要建线段树。
void build(int tr,int l,int r)
{
if(l==r)
{
f[tr]=newval[l];
return;
}
int mid=(l+r)>>1;
build(2*tr,l,mid);
build(2*tr+1,mid+1,r);
up(tr);
return;
}
pushup和pushdown都是和线段树一样的
void up(int tr)
{
f[tr]=(f[tr<<1]+f[tr*2+1])%mod;
return;
}
void pushdown(int tr,int l,int r)
{
if(add[tr])
{
add[tr<<1]=(add[tr<<1]+add[tr])%mod;
add[tr*2+1]=(add[tr*2+1]+add[tr])%mod;
int mid=(l+r)>>1;
f[tr<<1]=(f[tr<<1]+(mid-l+1)*add[tr]%mod)%mod;
f[tr*2+1]=(f[tr*2+1]+(r-mid)*add[tr]%mod)%mod;
add[tr]=0;
}
return;
}
然后就是和线段树一样的求和和修改操作了。
void update(int tr,int l,int r,int x,int y,int p)
{
if(x<=l&&r<=y)
{
f[tr]=(f[tr]+(r-l+1)*p%mod)%mod;
add[tr]=(add[tr]+p)%mod;
return;
}
pushdown(tr,l,r);
int mid=(l+r)>>1;
if(x<=mid)update(tr<<1,l,mid,x,y,p);
if(y>mid)update(tr*2+1,mid+1,r,x,y,p);
up(tr);
return;
}
int query(int tr,int l,int r ,int x,int y)
{
if(x<=l&&r<=y)
{
return f[tr]%mod;
}
pushdown(tr,l,r);
int mid=(l+r)>>1;
int ans=0;
if(x<=mid)ans=(ans+query(tr*2,l,mid,x,y))%mod;
if(y>mid)ans=(ans+query(tr*2+1,mid+1,r,x,y))%mod;
return ans;
}
然后呢,就是奇怪的求和和修改方式了。
因为我们这个树状的东西,要在上面建一颗线段树,我们在修改和查询的时候就有点麻烦,每次需要更换链进行区间的查询和修改,就需要一个新的函数了。
void treeadd(int x,int y,int p)
{
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])swap(x,y);
update(1,1,n,newnum[top[x]],newnum[x],p);
x=fa[top[x]];
}
if(deep[x]>deep[y])swap(x,y);
update(1,1,n,newnum[x],newnum[y],p);
return;
}
因为树剖就像倍增求lca一样,需要不断跳来跳去所以也可以用树剖求lca
这个是查询答案的
int treesum(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])swap(x,y);
ans=(ans+query(1,1,n,newnum[top[x]],newnum[x]))%mod;
x=fa[top[x]];
}
if(deep[x]>deep[y])swap(x,y);
ans=(ans+query(1,1,n,newnum[x],newnum[y]))%mod;
return ans;
}
总共的代码如下:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
inline int read()
{
int x=0,f=1;char ch=getchar();
while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
return x*f;
}
const int maxn=1000100;
const int maxm=1000010;
int point[maxn],top[maxn],son[maxn],deep[maxn];
int fa[maxn],newnum[maxn],newval[maxn];
int f[4*maxn],add[4*maxn];
int nxt[maxn],to[maxn];
int vis[maxn];
int val[maxn];
int n,m;
int roo,mod;
int tot;
int size[maxn];
int cnt;
void addedge(int x,int y)
{
nxt[++cnt]=point[x];
to[cnt]=y;
point[x]=cnt;
return;
}
void up(int tr)
{
f[tr]=(f[tr<<1]+f[tr*2+1])%mod;
return;
}
void dfs1(int x,int dep)
{
vis[x]=1;
deep[x]=dep;
size[x]=1;
int maxson=-1;
for(int i=point[x];i;i=nxt[i])
{
int p=to[i];
if(!vis[p])
{
fa[p]=x;
dfs1(p,dep+1);
size[x]+=size[p];
if(maxson<size[p])
{
maxson=size[p];
son[x]=p;
}
}
}
return;
}
void dfs2(int x,int chain)//chain是重链的祖先
{
newnum[x]=++tot;
newval[tot]=val[x];
top[x]=chain;
if(!son[x])return;
dfs2(son[x],chain);
for(int i=point[x];i;i=nxt[i])
{
int p=to[i];
if(!newnum[p])
{
dfs2(p,p);
}
}
return;
}
void pushdown(int tr,int l,int r)
{
if(add[tr])
{
add[tr<<1]=(add[tr<<1]+add[tr])%mod;
add[tr*2+1]=(add[tr*2+1]+add[tr])%mod;
int mid=(l+r)>>1;
f[tr<<1]=(f[tr<<1]+(mid-l+1)*add[tr]%mod)%mod;
f[tr*2+1]=(f[tr*2+1]+(r-mid)*add[tr]%mod)%mod;
add[tr]=0;
}
return;
}
void build(int tr,int l,int r)
{
if(l==r)
{
f[tr]=newval[l];
return;
}
int mid=(l+r)>>1;
build(2*tr,l,mid);
build(2*tr+1,mid+1,r);
up(tr);
return;
}
void update(int tr,int l,int r,int x,int y,int p)
{
if(x<=l&&r<=y)
{
f[tr]=(f[tr]+(r-l+1)*p%mod)%mod;
add[tr]=(add[tr]+p)%mod;
return;
}
pushdown(tr,l,r);
int mid=(l+r)>>1;
if(x<=mid)update(tr<<1,l,mid,x,y,p);
if(y>mid)update(tr*2+1,mid+1,r,x,y,p);
up(tr);
return;
}
int query(int tr,int l,int r ,int x,int y)
{
if(x<=l&&r<=y)
{
return f[tr]%mod;
}
pushdown(tr,l,r);
int mid=(l+r)>>1;
int ans=0;
if(x<=mid)ans=(ans+query(tr*2,l,mid,x,y))%mod;
if(y>mid)ans=(ans+query(tr*2+1,mid+1,r,x,y))%mod;
return ans;
}
void treeadd(int x,int y,int p)
{
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])swap(x,y);
update(1,1,n,newnum[top[x]],newnum[x],p);
x=fa[top[x]];
}
if(deep[x]>deep[y])swap(x,y);
update(1,1,n,newnum[x],newnum[y],p);
return;
}
int treesum(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]])swap(x,y);
ans=(ans+query(1,1,n,newnum[top[x]],newnum[x]))%mod;
x=fa[top[x]];
}
if(deep[x]>deep[y])swap(x,y);
ans=(ans+query(1,1,n,newnum[x],newnum[y]))%mod;
return ans;
}
int main()
{
n=read();m=read();roo=read();mod=read();
for(int i=1;i<=n;i++)
val[i]=read();
for(int i=1;i<n;i++)
{
int x,y;
x=read();
y=read();
addedge(x,y);
addedge(y,x);
}
dfs1(roo,1);
dfs2(roo,roo);
build(1,1,n);
for(int i=1;i<=m;i++)
{
int zhi,x,y,z;
zhi=read();
if(zhi==1)
{
x=read();y=read();z=read()%mod;
treeadd(x,y,z);
}
if(zhi==2)
{
x=read();y=read();
printf("%d\n",treesum(x,y));
}
if(zhi==3)
{
x=read();y=read()%mod;
update(1,1,n,newnum[x],newnum[x]+size[x]-1,y);
}
if(zhi==4)
{
x=read();
printf("%d\n",query(1,1,n,newnum[x],newnum[x]+size[x]-1)%mod);
}
}
return 0;
}