说一下树链剖分,一开始听这个名字感觉像是树上结构一样。结果是个数据结构。要理解这个数据结构我们先来了解几个概念:
树链剖分这里我们主要运用的是重链剖分。也就是说我们记录每个节点代表子树的size。每个节点向它size最大的儿子连一条重边,其它的连一条轻边。然后连出来大概是这样的:
我们把红色的叫做重链,其它的叫做轻边。很容易的可以看出来轻边的长度为1,每个点在且只在一条重链上。轻边把所有重链全部连在一起。然后我们可以证明我们跳不超过 logn l o g n 条重链可以跳到树根。(我最讨厌证明了。)不过还是理解一下,因为我们每次重边连size最大的点,所以我们每跳到重链链顶size就会至少乘2。可以想象,只有一颗完全二叉树才会把 log l o g 卡满,由此树链剖分是非常快的,因为它的 log l o g 基本不满。
在讲数据结构之前我们讲一讲怎么用树链剖分求LCA:
1.判断两个点是否在一条链上,若不同则把链顶深度深的点跳到链顶的父亲上
2.重复第一步
3.直到两个点在同一条链上,输出深度浅的点编号。
是不是很简单呢?
那么具体问题来了,我们怎么分链?怎么判断两个点是不是在一条链上?我们开一堆vector记录每一条的点么?
naive!!!
实际上我们并不需要直到整条链的整体信息,我们需要的操作其实就是判断两个点是否在一条链上。所以我们只用记录每个点所在重链的链顶。判断在一条链上只用判断它们的链顶是否相同就好。具体实现我们可以两遍dfs:
第一遍:记录 fa,deep,size f a , d e e p , s i z e
第二遍:记录 top t o p (每个点的链顶)
好,现在我们来学习树链剖分的数据结构部分:P3384 【模板】树链剖分
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
数据结构既视感,这种区间加区间求和啥的显然我们需要一颗线段树。所以我们先写线段树。
现在我们线段树写好了。
然后我们需要把树上结点按某种顺序放进线段树里。然后我们就需要刚刚的树链剖分。我们把节点按dfs序依次放进线段树里,大概像这样:
值得注意的是,由于我们每次先遍历重儿子,一条重链上的节点在线段树上一定是连续的一段。而且由于dfs序的性质,一颗子树在线段树里也是连续的一段。但是因为树上结点标号与线段树上节点不同,我们还要开一个数组表示树上每一个节点在线段树中的位置。
现在3,4操作很容易了对不对?我们先找到x在线段树中的位置(作为左端点),区间加/求和,区间长度为x的size。
1,2操作相对来说要复杂一些,不过也很简单(只是复杂度要复杂一些)。我们知道树上最短路就是两个节点到LCA的路程合起来。而我们刚刚学习求过LCA。我们就只用把跳过的节点做一次区间加/求和。因为同一条重链是连续的,所以是可以区间加/求和。注意就是只用跳重链的时候区间操作。最后在同一条链上后,我们把它们之间那一段也区间操作就行。
于是我们的复杂度就是 nlong2 n l o n g 2 的(LCA是 log l o g ,区间操作是 log l o g )
最后贴一波代码,稍微有点长,还有注意取模。
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
struct lxy{
int next,to;
}b[200005];
struct Tree{
int lson,rson,l,r,tag,num;
}a[200005];
int n,m,root,rootx,mod,cnt;
int fa[100005],top[100005],size[100005],id[100005],weigh[100005],head[100005],deeeep[100005];
bool vis[100005];
int data[100005];
void update(int u)
{
a[u].num=(a[a[u].lson].num+a[a[u].rson].num)%mod;
}
void pushdown(int u)
{
if(a[u].tag==0) return;
a[a[u].lson].num=(a[a[u].lson].num+a[u].tag*(a[a[u].lson].r-a[a[u].lson].l+1))%mod;
a[a[u].lson].tag=(a[a[u].lson].tag+a[u].tag)%mod;
a[a[u].rson].num=(a[a[u].rson].num+a[u].tag*(a[a[u].rson].r-a[a[u].rson].l+1))%mod;
a[a[u].rson].tag=(a[a[u].rson].tag+a[u].tag)%mod;
a[u].tag=0;
}
void build(int &u,int l,int r)
{
u=++cnt;
a[u].l=l;a[u].r=r;
if(l==r) return;
int mid=(l+r)/2;
build(a[u].lson,l,mid);
build(a[u].rson,mid+1,r);
}
void modify(int u,int l,int r,int x)
{
pushdown(u);
if(a[u].l==l&&a[u].r==r)
{
a[u].num=(a[u].num+(r-l+1)*x)%mod;
a[u].tag=(a[u].tag+x)%mod;
return;
}
int mid=(a[u].l+a[u].r)/2;
if(l>mid) modify(a[u].rson,l,r,x);
else if(r<=mid) modify(a[u].lson,l,r,x);
else modify(a[u].lson,l,mid,x),modify(a[u].rson,mid+1,r,x);
update(u);
}
int ques(int u,int l,int r)
{
pushdown(u);
if(a[u].l==l&&a[u].r==r)
return a[u].num;
int mid=(a[u].l+a[u].r)/2;
if(l>mid) return ques(a[u].rson,l,r);
else if(r<=mid) return ques(a[u].lson,l,r);
else return (ques(a[u].lson,l,mid)+ques(a[u].rson,mid+1,r))%mod;
}
//-------线段树---------
void add(int op,int ed)
{
cnt++;
b[cnt].to=ed;
b[cnt].next=head[op];
head[op]=cnt;
}
void dfs1(int u,int dep)
{
vis[u]=1;
size[u]=1;
deeeep[u]=dep;
int p=0;
for(int i=head[u];i!=-1;i=b[i].next)
if(vis[b[i].to]==0)
{
fa[b[i].to]=u;
dfs1(b[i].to,dep+1);
size[u]+=size[b[i].to];
if(p<size[b[i].to]) p=size[b[i].to],weigh[u]=b[i].to;
}
}
void dfs2(int u,int las)
{
vis[u]=1;
id[u]=++cnt;
top[u]=las;
modify(rootx,cnt,cnt,data[u]);//由于有初始值,dfs顺手加了
if(size[u]==1) return;
dfs2(weigh[u],las);
for(int i=head[u];i!=-1;i=b[i].next)
if(vis[b[i].to]==0)
dfs2(b[i].to,b[i].to);
}
int main()
{
scanf("%d%d%d%d",&n,&m,&root,&mod);
for(int i=1;i<=n;i++)scanf("%d",&data[i]);
memset(head,-1,sizeof(head));
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
cnt=0;build(rootx,1,n);
memset(vis,0,sizeof(vis));dfs1(root,1);
cnt=0;memset(vis,0,sizeof(vis));dfs2(root,root);
for(int i=1;i<=m;i++)
{
int q,x,y,z;
scanf("%d",&q);
if(q==1)
{
scanf("%d%d%d",&x,&y,&z);
while(top[x]!=top[y])//求LCA,很好理解的
{
if(deeeep[top[x]]<=deeeep[top[y]]) modify(rootx,id[top[y]],id[y],z),y=fa[top[y]];
else modify(rootx,id[top[x]],id[x],z),x=fa[top[x]];
}
if(deeeep[x]>=deeeep[y]) modify(rootx,id[y],id[x],z);
else modify(rootx,id[x],id[y],z);
}
if(q==2)
{
int ans=0;
scanf("%d%d",&x,&y);
while(top[x]!=top[y])
{
if(deeeep[top[x]]<=deeeep[top[y]]) ans=(ans+ques(rootx,id[top[y]],id[y]))%mod,y=fa[top[y]];
else ans=(ans+ques(rootx,id[top[x]],id[x]))%mod,x=fa[top[x]];
}
if(deeeep[x]>=deeeep[y]) ans=(ans+ques(rootx,id[y],id[x]))%mod;
else ans=(ans+ques(rootx,id[x],id[y]))%mod;
printf("%d\n",ans);
}
if(q==3) scanf("%d%d",&x,&y),modify(rootx,id[x],id[x]+size[x]-1,y);
if(q==4) scanf("%d",&x),printf("%d\n",ques(rootx,id[x],id[x]+size[x]-1));
}
}