191028-树链剖分
定义
指一种对树进行划分的算法,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、BST、SPLAY、线段树等)来维护每一条链。
那么什么是轻重边呢?
很简单,来看下面一张图:
在轻重链剖分中,每个非叶节点有且仅有一个重儿子,其余子节点为轻儿子。轻重儿子划分的依据是子树大小,对于每个非叶节点, 其子树size最大的子节点成为它的重儿子,其余节点成为它的轻儿子。
那么上图中红边即中边,黑边即轻边。
性质
(说明:重链可以是一条边,一条链,也可以是一个点)
应用
1,求LCA
代码如下
#include<bits/stdc++.h>
#define M 500006
using namespace std;
int first[M],to[M*2],nxt[M*2],f[M],dep[M],n,q,root,tot,size[M],num[M],top[M],idx[M],cnt,son[M];
void add(int x,int y)
{
nxt[++tot]=first[x];
first[x]=tot;
to[tot]=y;
}
void dfs1(int u,int fa)
{
dep[u]=dep[fa]+1;//深度
size[u]=1;//子树大小
f[u]=fa;//父亲
for(int i=first[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa) continue;
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;//更新重儿子
}
}
void dfs2(int u,int tp)
{
top[u]=tp;//该点所在重链编号
num[u]=++cnt;//该点的dfs序
if(son[u]) dfs2(son[u],tp);//先遍历重儿子
for(int i=first[u];i;i=nxt[i])
{
int v=to[i];
if(!num[v]) dfs2(v,v);
}
}
int lca(int u,int v)
{
while(top[u]!=top[v])//如果两个点不在一条重链上
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);//top值更深的点跳
u=f[top[u]];
}
return dep[u]<dep[v]?u:v;//最后还要判断一下深度
}
int main()
{
int x,y;
scanf("%d%d%d",&n,&q,&root);
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs1(root,0);
dfs2(root,root);
for(int i=1;i<=q;i++)
{
scanf("%d%d",&x,&y);
printf("%d\n",lca(x,y));
}
return 0;
}
2,路径信息维护
3,子树信息维护
…
例题
T1 树链剖分板子
解析
由题意,有四种操作,两种是对链操作,另外两种对子树操作
1,对链的操作
可以把链拆分成一些重链集合(即点+边+链),因为每次先遍历的重儿子,所以一条重链上的dfs序一定是连续的,因此直接在线段树进行维护即可
2,对子树的操作
子树就更为简单了,因为一棵子树的dfs序一定也是连续的即
n
u
m
[
x
]
num[x]
num[x]~
n
u
m
[
x
]
+
s
i
z
e
[
x
]
−
1
num[x]+size[x]-1
num[x]+size[x]−1,直接线段树维护即可
代码
#include<bits/stdc++.h>
#define int long long
#define M 200006
using namespace std;
int nxt[M*2],to[M*2],first[M],tot,f[M],size[M],son[M],dep[M],n,m,mod,rt,num[M],idx[M],top[M],cnt,a[M],vis;
struct node
{
int l,r,sum,add;
}tree[4*M];
int read()
{
int f=1,re=0;
char ch;
for(ch=getchar();!isdigit(ch)&&ch!='-';ch=getchar());
if(ch=='-')
{
f=-1;
ch=getchar();
}
for(;isdigit(ch);ch=getchar())
re=(re<<3)+(re<<1)+ch-'0';
return re*f;
}
void add1(int x,int y)
{
nxt[++tot]=first[x];
first[x]=tot;
to[tot]=y;
}
void dfs1(int u,int fa)
{
dep[u]=dep[fa]+1;
size[u]=1;
f[u]=fa;
for(int i=first[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa) continue;
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
}
void dfs2(int u,int tp)
{
top[u]=tp;
num[u]=++cnt;
idx[num[u]]=u;
if(son[u]) dfs2(son[u],tp);
for(int i=first[u];i;i=nxt[i])
{
int v=to[i];
if(!num[v]) dfs2(v,v);
}
}
void build(int k,int l,int r)
{
tree[k].l=l;
tree[k].r=r;
if(l==r)
{
tree[k].sum=a[idx[l]];
return;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
tree[k].sum=(tree[k<<1].sum+tree[k<<1|1].sum)%mod;
}
void add2(int k,int l,int r,int val)
{
tree[k].add+=val;
tree[k].add%=mod;
tree[k].sum+=(val*(r-l+1));
tree[k].sum%=mod;
}
void pushdown(int k,int l,int r,int mid)
{
if(tree[k].add==0) return;
add2(k<<1,l,mid,tree[k].add);
add2(k<<1|1,mid+1,r,tree[k].add);
tree[k].add=0;
}
void modify(int k,int l,int r,int val)
{
if(tree[k].l>=l&&tree[k].r<=r) return add2(k,tree[k].l,tree[k].r,val);
int mid=(tree[k].l+tree[k].r)>>1;
pushdown(k,tree[k].l,tree[k].r,mid);
if(l<=mid) modify(k<<1,l,r,val);
if(r>mid) modify(k<<1|1,l,r,val);
tree[k].sum=(tree[k<<1].sum+tree[k<<1|1].sum)%mod;
}
int solve(int k,int l,int r)
{
if(tree[k].l>=l&&tree[k].r<=r) return tree[k].sum%mod;
int mid=(tree[k].l+tree[k].r)>>1;
pushdown(k,tree[k].l,tree[k].r,mid);
int ret=0;
if(l<=mid) ret+=solve(k<<1,l,r);
ret%=mod;
if(r>mid) ret+=solve(k<<1|1,l,r);
return ret%mod;
}
void chain_add(int x,int y,int z)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
modify(1,num[top[x]],num[x],z);
x=f[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
modify(1,num[x],num[y],z);
}
int chain_solve(int x,int y)
{
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=solve(1,num[top[x]],num[x]);
ans%=mod;
x=f[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=solve(1,num[x],num[y]);
ans%=mod;
return ans;
}
void tree_add(int x,int y)
{
modify(1,num[x],num[x]+size[x]-1,y);
}
int tree_solve(int x)
{
return solve(1,num[x],num[x]+size[x]-1)%mod;
}
void debug(int k){//查错函数
if(tree[k].l==tree[k].r){
cout<<tree[k].sum<<" ";
return ;
}
debug(k<<1);debug(k<<1|1);
}
signed main()
{
int x,y,z;
scanf("%lld%lld%lld%lld",&n,&m,&rt,&mod);
for(int i=1;i<=n;i++) a[i]=read();
for(int i=1;i<n;i++)
{
x=read();
y=read();
add1(x,y);
add1(y,x);
}
dfs1(rt,0);
dfs2(rt,rt);
build(1,1,n);
for(int i=1;i<=m;i++)
{
vis=read();
if(vis==1)
{
x=read();
y=read();
z=read();
chain_add(x,y,z);
}
if(vis==2)
{
x=read();
y=read();
printf("%lld\n",chain_solve(x,y)%mod);
}
if(vis==3)
{
x=read();
y=read();
tree_add(x,y);
}
if(vis==4)
{
x=read();
printf("%lld\n",tree_solve(x)%mod);
}
}
return 0;
}