前言
首先,在学树链剖分之前要先把 LCA、树形DP、DFS序 这三个知识点学了
还有必备的 链式前向星、线段树 也要先学了。
树链剖分 就是对一棵树分成几条链,把树形变为线性,减少处理难度
需要处理的问题都有:
- 将树从x到y结点最短路径上所有节点的值都加上z
- 求树从x到y结点最短路径上所有节点的值之和
- 将以x为根节点的子树内所有节点值都加上z
- 求以x为根节点的子树内所有节点值之和
原题
》》题目
概念
思想:把一棵树拆成若干个不相交的链,然后用一些数据结构去维护这些链
定义:
- 重儿子:该节点的子树中,节点个数最多的子树的根节点(也就是和该节点相连的点),即为该节点的重儿子
- 轻儿子:对于每一个非叶子节点,它的儿子中 非重儿子 的剩下所有儿子即为轻儿子
叶子节点没有重儿子也没有轻儿子(因为它没有儿子。。) - 重边:一个父亲连接他的重儿子的边称为重边
- 轻边:剩下的即为轻边
- 重链:相邻重边连起来的 链叫重链
那么如何对这些链进行维护?
首先,要对这些链进行维护,就要确保每个链上的节点都是连续的,
因此我们需要对整棵树进行重新编号,然后利用dfs序的思想,用线段树或树状数组等进行维护。
注意在进行重新编号的时候先访问重链,这样可以保证重链内的节点编号连续
结合一张图来理解一下:
对于一棵最基本的树,给他标记重儿子,
(蓝色为重儿子,红色为重边然后对树进行重新编号)
橙色表示的是该节点重新编号后的序号,不难看出重链内的节点编号是连续的,然后就可以在线段树上搞事情啦!!像什么区间加区间求和什么的
另外有一个性质:以i为根的子树的树在线段树上的编号为 [ i , i + 子 树 节 点 数 − 1 ] [i,i+子树节点数-1] [i,i+子树节点数−1]
接下来结合例题,加深一下对于代码的理解
Part One
首先来一坨定义
int dep[MAXN];//节点的深度
int f[MAXN];//节点的父亲
int son[MAXN];//节点的重儿子
int tot[MAXN];//节点子树的大小
(1)按照我们上面说的,我们首先要对整棵树dfs一遍,找出每个节点的重儿子,顺便处理出每个节点的深度,以及他们的父亲节点
void dfs1(int x,int fa){//x当前节点,fa父亲
dep[x]=dep[fa]+1;//标记每个点的深度
f[x]=fa;//标记每个点的父亲
tot[x]=1;//标记每个非叶子节点的子树大小
int maxn=-1;
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(y==fa)continue;//若为父亲则continue
dfs1(y,x);
tot[x]+=tot[y];//把它的儿子数加到它身上
if(tot[y]>maxn)
{
maxn=tot[y];
son[x]=y;//标记每个非叶子节点的重儿子编号
}
}
}
(2)然后我们需要对整棵树进行重新编号(我把一开始的每个节点的权值存在了w数组内)
void dfs2(int x,int topf){//x当前节点,topf当前链的最顶端的节点
id[x]=++cnt;//重新编号后该节点的编号是多少
ww[cnt]=w[x];//把每个点的初始值赋到新编号上来
top[x]=topf;//记录这个点所在链的顶端
if(!son[x])return;//如果没有儿子则返回
dfs2(son[x],topf);//按先处理重儿子,再处理轻儿子的顺序递归处理
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(!id[y])
dfs2(y,y);//对于每一个轻儿子都有一条从它自己开始的链
}
}
Part Tow
线段树维护:
我们需要根据重新编完号的树,把这棵树的上每个点映射到线段树上
void build(int k,int l,int r)
{
t[k].x=l,t[k].y=r,t[k].size=r-l+1;
if(l==r)
{
t[k].w=ww[l];
return;
}
int mid=(l+r)>>1;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
up(k);
}
另外线段树的基本操作, 这里就不详细解释了,直接放代码
void up(int k){//更新
t[k].w=(t[k*2].w+t[k*2+1].w)%p;
}
void pushdown(int k){//下传标记
if(!t[k].add)return;
t[k*2].add=(t[k*2].add+t[k].add)%p;
t[k*2+1].add=(t[k*2+1].add+t[k].add)%p;
t[k*2].w=(t[k*2].w+t[k*2].size*t[k].add)%p;
t[k*2+1].w=(t[k*2+1].w+t[k*2+1].size*t[k].add)%p;
t[k].add=0;
}
int pushsum(int k,int l,int r){//区间求和
int ans=0;
if(t[k].x>=l&&t[k].y<=r)
return t[k].w;
int mid=(t[k].x+t[k].y)/2;
pushdown(k);
if(l<=mid)ans=(ans+pushsum(k*2,l,r))%p;
if(r>mid)ans=(ans+pushsum(k*2+1,l,r))%p;
return ans;
}
void pushadd(int k,int l,int r,int v){//区间加
if(t[k].x>=l&&t[k].y<=r)
{
t[k].w+=t[k].size*v;
t[k].add+=v;
return;
}
int mid=(t[k].x+t[k].y)/2;
pushdown(k);
if(l<=mid)pushadd(k*2,l,r,v);
if(r>mid)pushadd(k*2+1,l,r,v);
up(k);
}
Part Three
我们考虑如何实现对于树上的操作:
树链剖分的思想是:对于两个不在同一重链内的节点,让他们不断地跳,使得他们处于同一重链上
那么如何"跳”呢?还记得我们在第二次dfs中记录的top数组么?
有一个显然的结论:x到top[x]中的节点在线段树上是连续的。
结合dep数组,假设两个节点为x,y
我们每次让 d e e p [ t o p [ x ] ] deep[top[x]] deep[top[x]]与 d e e p [ t o p [ y ] ] deep[top[y]] deep[top[y]]中大的(在下面的)往上跳(有点类似于树上倍增)让x节点直接跳到 t o p [ x ] top[x] top[x]的前面一个,然后在线段树上更新。
最后两个节点一定是处于同一条重链的,前面我们提到过重链上的节点都是连续的,最后在线段树上进行一次查询就好
void Treeadd(int x,int y,int v){//对于x,y路径上的点加val的权值
while(top[x]!=top[y])//当两个点不在同一条链上
{
if(dep[top[x]]<dep[top[y]])swap(x,y);//把x点改为所在链顶端的深度更深的那个点
pushadd(1,id[top[x]],id[x],v);//x点到x所在链顶端 这一段区间的点权更新
x=f[top[x]];//把x跳到x所在链顶端的那个点的上面一个点
}
//直到两个点处于一条链上
if(dep[x]>dep[y])swap(x,y);//把x点深度更深的那个点
pushadd(1,id[x],id[y],v);//两个点间的区间点权更新
}
int Treesum(int x,int y){//x与y路径上的权值和
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);//把x点改为所在链顶端的深度更深的那个点
ans=(ans+pushsum(1,id[top[x]],id[x]))%p;//ans加上x点到x所在链顶端 这一段区间的点权和
x=f[top[x]];//把x跳到x所在链顶端的那个点的上面一个点
}
//直到两个点处于一条链上
if(dep[x]>dep[y])swap(x,y);//把x点深度更深的那个点
ans=(ans+pushsum(1,id[x],id[y]))%p;//这时再加上此时两个点间的区间和即可
return ans;
}
pushadd(1,id[x],id[x]+tot[x]-1,~);//处理一点及其子树的点权和
在树上查询的这一步可能有些抽象,我们结合一个例子来理解一下
还是上面那张图,假设我们要查询3.63.6这两个节点的之间的点权合,为了方便理解我们假设每个点的点权都是11
刚开始时
t o p [ 3 ] = 2 , t o p [ 6 ] = 1 top[3]=2,top[6]=1 top[3]=2,top[6]=1
d e e p [ t o p [ 3 ] ] = 2 , d e e p [ t o p [ 6 ] ] = 1 deep[top[3]]=2,deep[top[6]]=1 deep[top[3]]=2,deep[top[6]]=1
我们会让
3
3
3向上跳,跳到
t
o
p
[
3
]
top[3]
top[3]的爸爸,也就是
1
1
1号节点
跳完后1号节点和6号节点已经在同一条重链内,所以直接对线段树进行一次查询即可
时间复杂度
性质1
如果边
(
u
,
v
)
\left( u,v\right)
(u,v),为轻边,那么
S
i
z
e
(
v
)
≤
S
i
z
e
(
u
)
/
2
Size\left( v\right) \leq Size\left( u\right) /2
Size(v)≤Size(u)/2。
证明:显然,否则该边会成为重边
性质2
树中任意两个节点之间的路径中轻边的条数不会超过
log
2
n
\log _{2}n
log2n,重路径的数目不会超过
log
2
n
\log _{2}n
log2n
证明:不会
Code
#include<iostream>
#include<cmath>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define ll long long
using namespace std;
ll p;
int n,m,root,k,x,y,cnt,z,s;
int w[2000010],head[2000010],dep[2000010],f[2000010],tot[2000010],son[2000010];
int id[2000010],ww[2000010],top[2000010];
struct c{
int x,next;
}a[2000100];
struct cc{
int x,y,l,add,size,w;
}t[2000100];
void add(int x,int y){
a[++k].x=y;
a[k].next=head[x];
head[x]=k;
}
void dfs1(int x,int fa){
dep[x]=dep[fa]+1;
f[x]=fa;
tot[x]=1;
int maxn=-1;
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(y==fa)continue;
dfs1(y,x);
tot[x]+=tot[y];
if(tot[y]>maxn)
{
maxn=tot[y];
son[x]=y;
}
}
}
void dfs2(int x,int topf){
id[x]=++cnt;
ww[cnt]=w[x];
top[x]=topf;
if(!son[x])return;
dfs2(son[x],topf);
for(int i=head[x];i;i=a[i].next)
{
int y=a[i].x;
if(!id[y])
dfs2(y,y);
}
}
void up(int k){
t[k].w=(t[k*2].w+t[k*2+1].w)%p;
}
void pushdown(int k){
if(!t[k].add)return;
t[k*2].add=(t[k*2].add+t[k].add)%p;
t[k*2+1].add=(t[k*2+1].add+t[k].add)%p;
t[k*2].w=(t[k*2].w+t[k*2].size*t[k].add)%p;
t[k*2+1].w=(t[k*2+1].w+t[k*2+1].size*t[k].add)%p;
t[k].add=0;
}
void build(int k,int l,int r)
{
t[k].x=l,t[k].y=r,t[k].size=r-l+1;
if(l==r)
{
t[k].w=ww[l];
return;
}
int mid=(l+r)>>1;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
up(k);
}
void pushadd(int k,int l,int r,int v){
if(t[k].x>=l&&t[k].y<=r)
{
t[k].w+=t[k].size*v;
t[k].add+=v;
return;
}
int mid=(t[k].x+t[k].y)/2;
pushdown(k);
if(l<=mid)pushadd(k*2,l,r,v);
if(r>mid)pushadd(k*2+1,l,r,v);
up(k);
}
void Treeadd(int x,int y,int v){
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
pushadd(1,id[top[x]],id[x],v);
x=f[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
pushadd(1,id[x],id[y],v);
}
int pushsum(int k,int l,int r){
int ans=0;
if(t[k].x>=l&&t[k].y<=r)
return t[k].w;
int mid=(t[k].x+t[k].y)/2;
pushdown(k);
if(l<=mid)ans=(ans+pushsum(k*2,l,r))%p;
if(r>mid)ans=(ans+pushsum(k*2+1,l,r))%p;
return ans;
}
int Treesum(int x,int y){
int ans=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans=(ans+pushsum(1,id[top[x]],id[x]))%p;
x=f[top[x]];
}
if(dep[x]>dep[y])swap(x,y);
ans=(ans+pushsum(1,id[x],id[y]))%p;
return ans;
}
int main(){
scanf("%d%d%d%d",&n,&m,&root,&p);
for(int i=1;i<=n;i++)
scanf("%d",&w[i]);
for(int i=1;i<=n-1;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs1(root,0);
dfs2(root,root);
build(1,1,n);
for(int i=1;i<=m;i++)
{
scanf("%d",&s);
if(s==1)
{
scanf("%d%d%d",&x,&y,&z);
Treeadd(x,y,z%p);
}
if(s==2)
{
scanf("%d%d",&x,&y);
printf("%d\n",Treesum(x,y));
}
if(s==3)
{
scanf("%d%d",&x,&z);
pushadd(1,id[x],id[x]+tot[x]-1,z%p);
}
if(s==4)
{
scanf("%d",&x);
printf("%d\n",pushsum(1,id[x],id[x]+tot[x]-1));
}
}
}
/*
8 10 2 448348
458 718 447 857 633 264 238 944
1 2
2 3
3 4
6 2
1 5
5 7
8 6
3 7 611
4 6
3 1 267
3 2 111
1 6 3 153
3 7 673
4 8
2 6 1
4 7
3 4 228
1208
1055
2346
1900
*/