Q:什么是树链剖分?
A:树链剖分,计算机术语,指一种对树进行划分的算法,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、BST、SPLAY、线段树等)来维护每一条链。(摘自360百科)
Q:具体的一些概念呢?
A:
重儿子(节点):子树结点数目最多的结点
轻儿子(节点):除了重儿子以外的结点
重链:由多条重边连接而成的路径
轻链:由多条轻边连接而成的路径
Q:大概的步骤呢?
A:核心部分是预处理跑两次dfs
第一遍dfs:求出树每个结点的深度deep[x],其为根的子树大小size[x],以及祖先的信息fa[x]表示x的父亲。
第二遍dfs:根节点为起点,向下拓展构建重链,选择最大的一个子树的根继承当前重链,其余节点,都以那个节点为起点向下重新拉一条重链,给每个结点分配一个如同dfs序的编号,然后记录其所在重链的顶端节点,每条重链就相当于一段区间,用数据结构去维护。
最后就是把所有的重链首尾相接(通过dfs序),放到同一个数据结构上,然后维护这一个整体即可。(还不会的可以往下翻看代码,有注解)
这里放一道例题更好的了解树链剖分
题目描述
没过多的题目背景,现在给你一颗有n个结点的树,有n-1条边相连,每个节点一开始的值为0,现在有以下这几种操作:
1:给你两个数x,y,以及一个数z,表示把x到y的路径上的每个节点权值加z
2:给你一个数x,表示查询编号为x的节点及其子树节点的权值和
3:给你两个数x,y,表示查询x到y路径上的每个节点的权值和
现在给出n及这个树的边还有一个数m表示操作个数。
输入描述
两个数n,m
下面n-1行每行两个整数u,v,表示u,v之间有一条边
下面m行,每行第一个整数q,只为1,2,3中间的一个数
当q为1时:接下来三个整数x,y,z,如同题面操作
为2时:接下来一个整数x,如同题面操作
为3时:接下来两个整数x,y,如同题面操作
输出描述
对于每个q=1或2,输出一个整数表示答案
Code
#include<iostream>
using namespace std;
const int N=200005;
bool vis[N];
int n,sum,cnt,last[N],map[N],l[N],siz[N],dis[N],bson[N],larf[N],fa[N],num[N],lazy[N],sumval[N],ans;
void ad(int u,int v)
{
sum++;
map[sum]=v;
last[sum]=l[u];
l[u]=sum;
}
void build1(int x,int ds)
{
siz[x]=1;//用于记录子树大小
dis[x]=ds;//用于记录深度
vis[x]=false;
bson[x]=0;//记录最大儿子的编号
for (int i=l[x];i!=0;i=last[i])
{
if (vis[map[i]])
{
fa[map[i]]=x;
build1(map[i],ds+1);
siz[x]+=siz[map[i]];
if (siz[map[i]]>siz[bson[x]]) bson[x]=map[i];
}
}
}
void build2(int x,int bgfa)
{
num[x]=++cnt;//分一个dfs序编号给它
vis[x]=false;
larf[x]=bgfa;//记录所在重链顶端
if (bson[x]==0) return;
build2(bson[x],bgfa);//为了一条重链的编号连在一起,先走重儿子
for (int i=l[x];i!=0;i=last[i])
{
if (vis[map[i]]&&bson[x]!=map[i])
{
build2(map[i],map[i]);//再走其他儿子,以那个儿子为新的链顶开一条链
}
}
}
void down(int x,int y,int l,int r)
{
lazy[y]+=lazy[x];
sumval[y]=sumval[y]+(r-l+1)*lazy[x];
}
void change(int now,int l,int r,int mbl,int mbr,int as)
{
if (l>mbr||r<mbl) return;
if (l>=mbl&&r<=mbr)
{
lazy[now]+=as;
sumval[now]+=as*(r-l+1);
return;
}
int mid=(l+r)/2;
if (lazy[now]!=0)
{
down(now,now*2,l,mid);
down(now,now*2+1,mid+1,r);
lazy[now]=0;
}
change(now*2,l,mid,mbl,mbr,as);
change(now*2+1,mid+1,r,mbl,mbr,as);
sumval[now]=sumval[now*2]+sumval[now*2+1];
}
void add(int x,int y,int z)
{
while (larf[x]!=larf[y])
{
if (dis[larf[x]]>dis[larf[y]]) swap(x,y);
change(1,1,n,num[larf[y]],num[y],z);
y=fa[larf[y]];
}
if (dis[x]>dis[y]) swap(x,y);
change(1,1,n,num[x],num[y],z);
}
int query(int now,int l,int r,int mbl,int mbr)
{
if (l>mbr||r<mbl) return 0;
if (l>=mbl&&r<=mbr)
{
return sumval[now];
}
int mid=(l+r)/2;
if (lazy[now]!=0)
{
down(now,now*2,l,mid);
down(now,now*2+1,mid+1,r);
lazy[now]=0;
}
return(query(now*2,l,mid,mbl,mbr)+query(now*2+1,mid+1,r,mbl,mbr));
}
void query2(int x,int y)
{
while (larf[x]!=larf[y])
{
if (dis[larf[x]]>dis[larf[y]]) swap(x,y);
ans+=query(1,1,n,num[larf[y]],num[y]);
y=fa[larf[y]];
}
if (dis[x]>dis[y]) swap(x,y);
ans+=query(1,1,n,num[x],num[y]);
}
int main()
{
int t,u,v,q,x,y,z;
cin>>n>>t;
for (int i=1;i<=n;i++) vis[i]=true;
for (int i=1;i<=n-1;i++)
{
cin>>u>>v;
ad(u,v);
ad(v,u);
}
build1(1,1);//第一次dfs
for (int i=1;i<=n;i++) vis[i]=true;
build2(1,1);//第二次dfs
for (int i=1;i<=t;i++)
{
cin>>q;
if (q==1)
{
cin>>x>>y>>z;
add(x,y,z);
}
if (q==2)
{
cin>>x;
ans=query(1,1,n,num[x],num[x]+siz[x]-1);
cout<<ans<<endl;
}
if (q==3)
{
cin>>x>>y;
ans=0;
query2(x,y);
cout<<ans<<endl;
}
}
}
这次的博客就到这里,希望能给大家一些帮助