树链剖分概述

说一下树链剖分,一开始听这个名字感觉像是树上结构一样。结果是个数据结构。要理解这个数据结构我们先来了解几个概念:

树链剖分这里我们主要运用的是重链剖分。也就是说我们记录每个节点代表子树的size。每个节点向它size最大的儿子连一条重边,其它的连一条轻边。然后连出来大概是这样的:

这里写图片描述

我们把红色的叫做重链,其它的叫做轻边。很容易的可以看出来轻边的长度为1,每个点在且只在一条重链上。轻边把所有重链全部连在一起。然后我们可以证明我们跳不超过 logn l o g n 条重链可以跳到树根。(我最讨厌证明了。)不过还是理解一下,因为我们每次重边连size最大的点,所以我们每跳到重链链顶size就会至少乘2。可以想象,只有一颗完全二叉树才会把 log l o g 卡满,由此树链剖分是非常快的,因为它的 log l o g 基本不满。

在讲数据结构之前我们讲一讲怎么用树链剖分求LCA:

1.判断两个点是否在一条链上,若不同则把链顶深度深的点跳到链顶的父亲上

2.重复第一步

3.直到两个点在同一条链上,输出深度浅的点编号。

是不是很简单呢?

那么具体问题来了,我们怎么分链?怎么判断两个点是不是在一条链上?我们开一堆vector记录每一条的点么?

naive!!!

实际上我们并不需要直到整条链的整体信息,我们需要的操作其实就是判断两个点是否在一条链上。所以我们只用记录每个点所在重链的链顶。判断在一条链上只用判断它们的链顶是否相同就好。具体实现我们可以两遍dfs:

第一遍:记录 fadeepsize f a , d e e p , s i z e

第二遍:记录 top t o p (每个点的链顶)

好,现在我们来学习树链剖分的数据结构部分:P3384 【模板】树链剖分

如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:

操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和

数据结构既视感,这种区间加区间求和啥的显然我们需要一颗线段树。所以我们先写线段树。

现在我们线段树写好了。

然后我们需要把树上结点按某种顺序放进线段树里。然后我们就需要刚刚的树链剖分。我们把节点按dfs序依次放进线段树里,大概像这样:

这里写图片描述

值得注意的是,由于我们每次先遍历重儿子,一条重链上的节点在线段树上一定是连续的一段。而且由于dfs序的性质,一颗子树在线段树里也是连续的一段。但是因为树上结点标号与线段树上节点不同,我们还要开一个数组表示树上每一个节点在线段树中的位置。

现在3,4操作很容易了对不对?我们先找到x在线段树中的位置(作为左端点),区间加/求和,区间长度为x的size。

1,2操作相对来说要复杂一些,不过也很简单(只是复杂度要复杂一些)。我们知道树上最短路就是两个节点到LCA的路程合起来。而我们刚刚学习求过LCA。我们就只用把跳过的节点做一次区间加/求和。因为同一条重链是连续的,所以是可以区间加/求和。注意就是只用跳重链的时候区间操作。最后在同一条链上后,我们把它们之间那一段也区间操作就行。

于是我们的复杂度就是 nlong2 n l o n g 2 的(LCA是 log l o g ,区间操作是 log l o g

最后贴一波代码,稍微有点长,还有注意取模。

#include<iostream>  
#include<cstdio>  
#include<cstring>  
using namespace std;  
struct lxy{  
    int next,to;  
}b[200005];  
struct Tree{  
    int lson,rson,l,r,tag,num;  
}a[200005];  

int n,m,root,rootx,mod,cnt;  
int fa[100005],top[100005],size[100005],id[100005],weigh[100005],head[100005],deeeep[100005];  
bool vis[100005];  
int data[100005];  

void update(int u)  
{  
    a[u].num=(a[a[u].lson].num+a[a[u].rson].num)%mod;  
}  

void pushdown(int u)  
{  
    if(a[u].tag==0) return;  
    a[a[u].lson].num=(a[a[u].lson].num+a[u].tag*(a[a[u].lson].r-a[a[u].lson].l+1))%mod;  
    a[a[u].lson].tag=(a[a[u].lson].tag+a[u].tag)%mod;  
    a[a[u].rson].num=(a[a[u].rson].num+a[u].tag*(a[a[u].rson].r-a[a[u].rson].l+1))%mod;  
    a[a[u].rson].tag=(a[a[u].rson].tag+a[u].tag)%mod;  
    a[u].tag=0;  
}  

void build(int &u,int l,int r)  
{  
    u=++cnt;  
    a[u].l=l;a[u].r=r;  
    if(l==r) return;  
    int mid=(l+r)/2;  
    build(a[u].lson,l,mid);  
    build(a[u].rson,mid+1,r);  
}  

void modify(int u,int l,int r,int x)  
{  
    pushdown(u);  
    if(a[u].l==l&&a[u].r==r)  
    {  
        a[u].num=(a[u].num+(r-l+1)*x)%mod;  
        a[u].tag=(a[u].tag+x)%mod;  
        return;  
    }  
    int mid=(a[u].l+a[u].r)/2;  
    if(l>mid) modify(a[u].rson,l,r,x);  
    else if(r<=mid) modify(a[u].lson,l,r,x);  
    else modify(a[u].lson,l,mid,x),modify(a[u].rson,mid+1,r,x);  
    update(u);  
}  

int ques(int u,int l,int r)  
{  
    pushdown(u);  
    if(a[u].l==l&&a[u].r==r)  
      return a[u].num;  
    int mid=(a[u].l+a[u].r)/2;  
    if(l>mid) return ques(a[u].rson,l,r);  
    else if(r<=mid) return ques(a[u].lson,l,r);  
    else return (ques(a[u].lson,l,mid)+ques(a[u].rson,mid+1,r))%mod;  
}  

//-------线段树---------   

void add(int op,int ed)  
{  
    cnt++;  
    b[cnt].to=ed;  
    b[cnt].next=head[op];  
    head[op]=cnt;  
}  

void dfs1(int u,int dep)  
{  
    vis[u]=1;  
    size[u]=1;  
    deeeep[u]=dep;  
    int p=0;  
    for(int i=head[u];i!=-1;i=b[i].next)  
      if(vis[b[i].to]==0)  
      {  
        fa[b[i].to]=u;  
        dfs1(b[i].to,dep+1);  
        size[u]+=size[b[i].to];  
        if(p<size[b[i].to]) p=size[b[i].to],weigh[u]=b[i].to;  
      }  
}  

void dfs2(int u,int las)  
{  
    vis[u]=1;  
    id[u]=++cnt;  
    top[u]=las;  
    modify(rootx,cnt,cnt,data[u]);//由于有初始值,dfs顺手加了  
    if(size[u]==1) return;  
    dfs2(weigh[u],las);  
    for(int i=head[u];i!=-1;i=b[i].next)  
      if(vis[b[i].to]==0)  
        dfs2(b[i].to,b[i].to);  
}  

int main()  
{  
    scanf("%d%d%d%d",&n,&m,&root,&mod);  
    for(int i=1;i<=n;i++)scanf("%d",&data[i]);  
    memset(head,-1,sizeof(head));  
    for(int i=1;i<n;i++)  
    {  
        int x,y;  
        scanf("%d%d",&x,&y);  
        add(x,y);add(y,x);  
    }  
    cnt=0;build(rootx,1,n);  
    memset(vis,0,sizeof(vis));dfs1(root,1);  
    cnt=0;memset(vis,0,sizeof(vis));dfs2(root,root);  
    for(int i=1;i<=m;i++)  
    {  
        int q,x,y,z;  
        scanf("%d",&q);  
        if(q==1)  
        {  
            scanf("%d%d%d",&x,&y,&z);  
            while(top[x]!=top[y])//求LCA,很好理解的  
            {  
                if(deeeep[top[x]]<=deeeep[top[y]]) modify(rootx,id[top[y]],id[y],z),y=fa[top[y]];  
                else modify(rootx,id[top[x]],id[x],z),x=fa[top[x]];  
            }  
            if(deeeep[x]>=deeeep[y]) modify(rootx,id[y],id[x],z);  
            else modify(rootx,id[x],id[y],z);  
        }  
        if(q==2)  
        {  
            int ans=0;  
            scanf("%d%d",&x,&y);  
            while(top[x]!=top[y])  
            {  
                if(deeeep[top[x]]<=deeeep[top[y]]) ans=(ans+ques(rootx,id[top[y]],id[y]))%mod,y=fa[top[y]];  
                else ans=(ans+ques(rootx,id[top[x]],id[x]))%mod,x=fa[top[x]];  
            }  
            if(deeeep[x]>=deeeep[y]) ans=(ans+ques(rootx,id[y],id[x]))%mod;  
            else ans=(ans+ques(rootx,id[x],id[y]))%mod;  
            printf("%d\n",ans);  
        }  
        if(q==3) scanf("%d%d",&x,&y),modify(rootx,id[x],id[x]+size[x]-1,y);  
        if(q==4) scanf("%d",&x),printf("%d\n",ques(rootx,id[x],id[x]+size[x]-1));  
    }  
}  
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值