树链剖分基础模板(BZOJ1036[ZJOI2008]树的统计Count)

摘自XZY的博客

1. 前言

如果给你一棵树,求点u到点v路径上点的权值之和,你可能会说:倍增啊!

那如果出题人:我还要你支持修改某个点的权值!

或者再j一点:我还要你支持修改点u到点v路径上点的权值!

那就得用树链剖分了。

2. 什么是树链剖分

上面那个问题,树上区间修改。

区间修改最常见做法就是线段树了。

那我们怎么用线段树维护一颗。。。普通的树呢?

那就给普通的树的每个节点标个号,然后放线段树里呗。区间维护。

但如果随便标号,那点u到点v路径不一定标号是连续的啊,你线段树维护个j啊

所以我们现在引入一个(堆)姿势:



那么这些姿势有什么用呢?先看一道题吧!

传送门= ̄ω ̄=

我们先来看张图:

图中标在边上了,但也不影响我们学。。。

标号方法是:跑dfs,先给当前节点标号,再给重儿子标号(重儿子和当前节点在一个重链上),然后对重儿子递归,最后给剩下的别的儿子标号(别的儿子不和当前节点在一个重链上,所以新建重链,把新建的重链的顶端节点设为那个”别的儿子“)、递归(图中是先给重(zhong)边标号,再给剩下的边标号)。标号从小到大。

不难发现一条重链上的标号是连续的,比如点1到点14,点2到点12

这意味着在线段树中,它们是在一个连续的区间里的,而不是像随便标号时断断续续的。

这样就很好用线段树处理了。

如果两个节点不在一条重链上呢?比如图中的点11和点10,我们要求它们之间的路径上的点权和

那我们就看,点11所在的重链是11->6->2,点10所在的重链是10(10所在的重链只有一个点,就是10)。所以我们就先求出重链11->6->2上的点权和、重链10上的点权和。这两条重链在线段树上都是一段连续的区间,可以直接log2n求出

这时候我们发现还有4->1的重链没有计算,就把它的点权和计算出来,三个重链的点权和加在一起就得到了答案。

所以我们要记录的是:
1. pos[i] 点i的标号
2. top[i] 点i所在重链的顶端节点
3. siz[i] 以点i为根的子树的大小
4. dep[i] 点i的深度
5. fa[i] 点i的父亲节点

我们先跑一边dfs,算出fa、size、dep

然后再跑一边dfs,根据size[i]找出点i的重儿子,然后算出pos、top。

搞完这些就很easy了,因为一段重链在线段树里是一段连续的区间(这是坠重要的)。

我们在查询/修改从点u到点v的路径时,先找到所在重链的顶端节点(top)深度较深的(因为这样能让u和v同步提升,防止一个提到根节点了,另一个还没提,这时候你就不知道提谁了),注意不能按照u和v的深度来提!比如top较深的点是u,然后就用线段树处理区间[pos[top[u]],pos[u]](因为top[u]的标号一定比u要小),再设置u为fa[top[u]],把u往上提,直至u和v在一条重链上(即top[u]==top[v])。这时候可能u和v之间还有一段距离,此时u和v已经在一条重链上,直接处理它们之间的区间就行了。

然后复杂度就是:O(Nlog_{2} N+Qlog^{2} N)

同时这个复杂度也是一般的树链剖分的复杂度。因为重链个数不会超过log_{2} N个,线段树复杂度是log_{2} N的。网上有证明,我就不做过多赘述

 

然后不要脸的放上我的代码。。

/*
siz[]数组,用来保存以x为根的子树节点个数
top[]数组,用来保存当前节点的所在链的顶端节点
son[]数组,用来保存重儿子
dep[]数组,用来保存当前节点的深度
fa[]数组,用来保存当前节点的父亲
pos[]数组,用来保存树中每个节点剖分后的新编号
rank[]数组,用来保存当前节点在线段树中的位置
*/
#pragma GCC optimize("O2")
#include<bits/stdc++.h>
#define maxn 50000
#define ls (rt<<1)
#define rs (rt<<1|1)
using namespace std;
 
//init begin
struct TREE
{
    int l,r,sum,max;
};
TREE tr[maxn<<2];
char opt[200];
int pos[maxn],fa[maxn],sz[maxn],w[maxn],deep[maxn],q,n,m,cnt=0,top[maxn],son[maxn],rank[maxn];
vector<int> g[maxn];
//init end
 
//IntervalTree begin
 
void pushup(int rt)
{
    tr[rt].sum=tr[ls].sum+tr[rs].sum;
    tr[rt].max=max(tr[ls].max,tr[rs].max);
    return ;
}
 
void build(int l,int r,int rt)
{
    tr[rt].l=l,tr[rt].r=r;
    if(l==r)
        {tr[rt].sum=tr[rt].max=w[l];return ;}
    int mid=l+(r-l)/2;
    build(l,mid,ls),build(mid+1,r,rs);
    pushup(rt);
}
 
void update(int l,int c,int rt)
{
    if(tr[rt].l==tr[rt].r)
    {
        tr[rt].sum=tr[rt].max=c;
        return ;
    }
    int mid=tr[rt].l+(tr[rt].r-tr[rt].l)/2;
    if(l<=mid) update(l,c,ls);
    else update(l,c,rs);
    pushup(rt);
}
 
int query_max(int l,int r,int rt)
{
    if(l<=tr[rt].l&&tr[rt].r<=r) return tr[rt].max;
    int mid=tr[rt].l+(tr[rt].r-tr[rt].l)/2,ans=INT_MIN;
    if(l<=mid) ans=max(ans,query_max(l,r,ls));
    if(r>mid) ans=max(ans,query_max(l,r,rs));
    return ans;
}
 
int query_sum(int l,int r,int rt)
{
    if(l<=tr[rt].l&&tr[rt].r<=r) return tr[rt].sum;
    int mid=tr[rt].l+(tr[rt].r-tr[rt].l)/2,ans=0;
    if(l<=mid) ans+=query_sum(l,r,ls);
    if(r>mid) ans+=query_sum(l,r,rs);
    return ans;
}
 
//IntervalTree end
 
void dfs1(int x,int fat,int d)
{
    sz[x]=1;deep[x]=d;fa[x]=fat;
    for(int i=0;i<g[x].size();i++)
        if(fa[x]!=g[x][i])
        {
            dfs1(g[x][i],x,d+1);
			sz[x]+=sz[g[x][i]];
			if(son[x]==-1||sz[g[x][i]]>sz[son[x]])
				son[x]=g[x][i];
        }
}
 
void dfs2(int x,int tp)
{
    top[x]=tp,pos[x]=++cnt,rank[pos[x]]=x;
    if(son[x]==-1) return ;
    dfs2(son[x],tp);
    for(int i=0;i<g[x].size();i++)
        if(g[x][i]!=fa[x]&&g[x][i]!=son[x])
            dfs2(g[x][i],g[x][i]);
}
 
 
int lca(int a,int b,int ok)
{
    int ans=ok?INT_MIN:0;
    while(top[a]!=top[b])
    {
        if(deep[top[a]]>deep[top[b]]) swap(a,b);
        if(!ok) ans+=query_sum(pos[top[b]],pos[b],1);
        else ans=max(ans,query_max(pos[top[b]],pos[b],1));
        b=fa[top[b]];
    }
    if(deep[a]>deep[b]) swap(a,b);
    if(!ok) ans+=query_sum(pos[a],pos[b],1);
    else ans=max(ans,query_max(pos[a],pos[b],1));
    return ans;
}
 
int main()
{  
    int u,v;
    memset(son,-1,sizeof(son));
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        scanf("%d%d",&u,&v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs1(1,-1,1),top[1]=1,dfs2(1,1);
    for(int i=1;i<=n;i++)
        scanf("%d",&w[pos[i]]);
    build(1,cnt,1);
    scanf("%d",&q);
    for(int i=1;i<=q;i++)
    {
        scanf("%s%d%d",opt,&u,&v);
        if(opt[1]=='M')printf("%d\n",lca(u,v,1));
        else if(opt[1]=='H') update(pos[u],v,1);
        else if(opt[1]=='S') printf("%d\n",lca(u,v,0));
    }
    return 0;
}

  

  

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
#include <cstdio> #include <iostream> #include <vector> #define N 30003 #define INF 2147483647 using namespace std; int n,f[N][20],dep[N],siz[N],son[N],top[N],tot,pos[N],w[N]; int Max[N*4],Sum[N*4]; vector <int> to[N]; void dfs1(int x){ siz[x]=1; int sz=to[x].size(); for(int i=0;i<sz;++i){ int y=to[x][i]; if(y==f[x][0])continue; f[y][0]=x; dep[y]=dep[x]+1; dfs1(y); siz[x]+=siz[y]; if(siz[y]>siz[son[x]])son[x]=y; } } void dfs2(int x,int root){ top[x]=root; pos[x]=++tot; if(son[x])dfs2(son[x],root); int sz=to[x].size(); for(int i=0;i<sz;++i){ int y=to[x][i]; if(y==f[x][0] || y==son[x])continue; dfs2(y,y); } } void update(int k,int l,int r,int P,int V){ if(l==r){ Max[k]=Sum[k]=V; return; } int mid=(l+r)>>1; if(P<=mid)update(k*2,l,mid,P,V); else update(k*2+1,mid+1,r,P,V); Max[k]=max(Max[k*2],Max[k*2+1]); Sum[k]=Sum[k*2]+Sum[k*2+1]; } void up(int &x,int goal){ for(int i=15;i>=0;--i) if(dep[f[x][i]]>=goal)x=f[x][i]; } int lca(int x,int y){ if(dep[x]>dep[y])up(x,dep[y]); if(dep[x]<dep[y])up(y,dep[x]); if(x==y)return x; for(int i=15;i>=0;--i) if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i]; return f[x][0]; } int getm(int k,int l,int r,int L,int R){ if(L<=l && r<=R)return Max[k]; int res=-INF,mid=(l+r)>>1; if(L<=mid)res=max(res,getm(k*2,l,mid,L,R)); if(R>mid)res=max(res,getm(k*2+1,mid+1,r,L,R)); return res; } int gets(int k,int l,int r,int L,int R){ if(L<=l && r<=R)return Sum[k]; int res=0,mid=(l+r)>>1; if(L<=mid)res+=gets(k*2,l,mid,L,R); if(R>mid)res+=gets(k*2+1,mid+1,r,L,R); return res; } int main(){ scanf("%d",&n); for(int i=1,a,b;i<n;++i){ scanf("%d%d",&a,&b); to[a].push_back(b); to[b].push_back(a); } dep[1]=1; dfs1(1); dfs2(1,1); for(int i=1;i<=15;++i) for(int j=1;j<=n;++j)f[j][i]=f[f[j][i-1]][i-1]; for(int i=1;i<=n;++i){ scanf("%d",&w[i]); update(1,1,n,pos[i],w[i]); } int q; scanf("%d",&q); while(q--){ char s[10]; int u,v,t; scanf("%s",s); if(s[1]=='H'){ scanf("%d%d",&u,&t); w[u]=t; update(1,1,n,pos[u],t); } if(s[1]=='M'){ scanf("%d%d",&u,&v); int ans=-INF,t=lca(u,v); for(int i=u;i;i=f[top[i]][0]) if(dep[t]<dep[top[i]]) ans=max(ans,getm(1,1,n,pos[top[i]],pos[i])); else{ ans=max(ans,getm(1,1,n,pos[t],pos[i])); break; } for(int i=v;i;i=f[top[i]][0]) if(dep[t]<dep[top[i]]) ans=max(ans,getm(1,1,n,pos[top[i]],pos[i])); else{ ans=max(ans,getm(1,1,n,pos[t],pos[i])); break; } printf("%d\n",ans); } if(s[1]=='S'){ scanf("%d%d",&u,&v); int ans=0,t=lca(u,v); for(int i=u;i;i=f[top[i]][0]) if(dep[t]<dep[top[i]]) ans+=gets(1,1,n,pos[top[i]],pos[i]); else{ ans+=gets(1,1,n,pos[t],pos[i]); break; } for(int i=v;i;i=f[top[i]][0]) if(dep[t]<dep[top[i]]) ans+=gets(1,1,n,pos[top[i]],pos[i]); else{ ans+=gets(1,1,n,pos[t],pos[i]); break; } printf("%d\n",ans-w[t]); } } }

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值