树链剖分
发现最近几天可以出专题了。。。近几天搞板子题真的是逼着我写一些东西。。。那么我们来搞一搞树链剖分
原理
树链剖分,实际上就是一种把树结构映射到一颗线段树结构上的算法,常用于搞各种树上的两点路径查询及修改的问题,但树的形态不能改变,否则要改用LCT,然而我还不会hehe
我们记录如下的一些东西:
top 数组,用于记录每个点所在的链的顶端的节点
dfs 数组,用于记录每个点在 DFS 序列中的位置
size 数组,记录每个节点的子树大小
son 数组,记录每个点的重儿子
fa 数组,记录每个点的父亲
dep 数组,记录每个节点的深度
i2x 数组,就我个人而言,这是我的习惯,是 dfs 的一个反映射
现在说一下最常用的树剖方法:轻重树链剖分法,每次我们找一个节点的儿子中子树大小最大的那一个,然后把它与原节点归到同一个链子上,其余的节点再分别作为其他链子的顶端接着搞
其实挺简单的,就是代码量大,调试困难,容易写错,浪费时间罢了
代码
恩,下面这个是一个带两点路径查询以及两点路径修改的树剖,查询时查找SUM和MAX
#include<cstdlib>
#include<cstdio>
#include<algorithm>
#include<vector>
#define maxn 100005
#define time t
using namespace std;
int t=0,n,m,qx,qy,qd,dep[maxn],size[maxn],fa[maxn],id[maxn],id2x[maxn],son[maxn],top[maxn],line[maxn],sum[maxn*2],maxnum[maxn*2];
vector<int> geo[maxn];
void DFS1(int u){
dep[u]=dep[fa[u]]+1;
size[u]=1;
son[u]=0;
for(int i=0;i<geo[u].size();i++){
int op=geo[u][i];
if(op==fa[u])continue;
fa[op]=u;
DFS1(op);
size[u]+=size[op];
if(size[son[u]]<size[op]){
son[u]=op;
}
}
}
void DFS2(int x,int tp){
top[x]=tp;
id[x]=++time;
id2x[time]=x;
if(son[x])DFS2(son[x],tp);
for(int i=0;i<geo[x].size();i++){
int op=geo[x][i];
if(op==fa[x]||op==son[x])continue;
DFS2(op,op);
}
}
void build(int l,int r,int o){
if(l==r){
sum[o]=line[id2x[l]],maxnum[o]=line[id2x[l]];
return;
}
int mid=((r-l)>>1)+l;
build(l,mid,o<<1);
build(mid+1,r,(o<<1)+1);
sum[o]=sum[o<<1]+sum[(o<<1)+1];
maxnum[o]=max(maxnum[o<<1],maxnum[(o<<1)+1]);
return;
}
void UPDATE(int l,int r,int o){
if(qx<=l&&r<=qy){
sum[o]=qd,maxnum[o]=qd;
return;
}
int mid=((r-l)>>1)+l;
if(qx<=mid)UPDATE(l,mid,o<<1);
if(mid<qy)UPDATE(mid+1,r,(o<<1)+1);
sum[o]=sum[o<<1]+sum[(o<<1)+1];
maxnum[o]=max(maxnum[o<<1],maxnum[(o<<1)+1]);
return;
}
int get_sum(int l,int r,int o){
if(qx<=l&&r<=qy){
return sum[o];
}
int mid=((r-l)>>1)+l,ans=0;
if(qx<=mid){
ans+=get_sum(l,mid,o<<1);
}
if(qy>mid){
ans+=get_sum(mid+1,r,(o<<1)+1);
}
return ans;
}
int get_max(int l,int r,int o){
if(qx<=l&&r<=qy){
return maxnum[o];
}
int mid=((r-l)>>1)+l,ans=0;
if(qx<=mid){
ans=max(ans,get_max(l,mid,o<<1));
}
if(mid<qy)
ans=max(ans,get_max(mid+1,r,(o<<1)+1));
return ans;
}
void init(){
DFS1(1);
DFS2(1,1);
build(1,n,1);
}
void MAX(int x,int y){
int f1=top[x];
int f2=top[y];
int ans=0;
while(f1!=f2){
if(dep[f1]>dep[f2]){
qx=id[f1],qy=id[x];
ans=max(ans,get_max(1,n,1));
x=fa[f1];
}
else{
qx=id[f2],qy=id[y];
ans=max(ans,get_max(1,n,1));
y=fa[f2];
}
f1=top[x];
f2=top[y];
}
if(dep[x]>dep[y])swap(x,y);
qx=id[x],qy=id[y];
ans=max(ans,get_max(1,n,1));
printf("%d\n",ans);
}
void SUM(int x,int y){
int f1=top[x];
int f2=top[y];
int ans=0;
while(f1!=f2){
if(dep[f1]>dep[f2]){
qx=id[f1],qy=id[x];
ans+=get_sum(1,n,1);
x=fa[f1];
}
else{
qx=id[f2],qy=id[y];
ans+=get_sum(1,n,1);
y=fa[f2];
}
f1=top[x];
f2=top[y];
}
if(dep[x]>dep[y])swap(x,y);
qx=id[x],qy=id[y];
ans+=get_sum(1,n,1);
printf("%d\n",ans);
}
void work(int op,int x,int y){
if(op==0){
qx=id[x],qy=id[x],qd=y;
UPDATE(1,n,1);
}
else if(op==1){
MAX(x,y);
}
else{
SUM(x,y);
}
return;
}
int main(){
/*freopen("input.txt","r",stdin);
freopen("output.txt","w",stdout);*/
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&line[i]);
int x,y;
char ch[10];
for(int i=0;i<n-1;i++){
scanf("%d%d",&x,&y);
geo[y].push_back(x);
geo[x].push_back(y);
}
init();
for(int i=0;i<m;i++){
scanf("%s%d%d",ch,&x,&y);
if(ch[0]=='U'){
work(0,x,y);
}
else if(ch[0]=='M'){
work(1,x,y);
}
else{
work(2,x,y);
}
}
return 0;
}
下面这个是带两点路径查询/修改以及子树查询/修改的代码,查询时查找SUM,同时对M取模(这其实是洛谷上的模板题的板子,破事巨多,数据范围还有问题,而且容易爆栈,真是服了)
#include<cstdlib>
#include<cstdio>
#include<algorithm>
#include<vector>
#define maxn 200005
#define time t
#define LL long long int
using namespace std;
int t=0,root,M,n,m,qx,qy,dep[maxn],size[maxn],fa[maxn],id[maxn],id2x[maxn],son[maxn],top[maxn];
LL qd,sum[maxn*2],/*maxnum[maxn*2],*/add[maxn*2],line[maxn];
vector<int> geo[maxn];
void DFS1(int u){
dep[u]=dep[fa[u]]+1;
size[u]=1;
son[u]=0;
for(int i=0;i<geo[u].size();i++){
int op=geo[u][i];
if(op==fa[u])continue;
fa[op]=u;
DFS1(op);
size[u]+=size[op];
if(size[son[u]]<size[op]){
son[u]=op;
}
}
}
void DFS2(int x,int tp){
top[x]=tp;
id[x]=++time;
id2x[time]=x;
if(son[x])DFS2(son[x],tp);
for(int i=0;i<geo[x].size();i++){
int op=geo[x][i];
if(op==fa[x]||op==son[x])continue;
DFS2(op,op);
}
}
void build(int l,int r,int o){
if(l==r){
sum[o]=line[id2x[l]]/*,maxnum[o]=line[id2x[l]]*/;
return;
}
int mid=((r-l)>>1)+l;
build(l,mid,o<<1);
build(mid+1,r,(o<<1)+1);
sum[o]=(sum[o<<1]+sum[(o<<1)+1])%M;
//maxnum[o]=max(maxnum[o<<1],maxnum[(o<<1)+1]);
return;
}
void update(int l,int r,int o){
if(qx<=l&&r<=qy){
add[o]=(add[o]+qd)%M;sum[o]=(sum[o]+(((r-l+1)%M)*qd%M))%M;//maxnum[o]+=a;
return;
}
int mid=((r-l)>>1)+l;
if(qx<=mid)update(l,mid,o<<1);
if(mid<qy)update(mid+1,r,(o<<1)+1);
sum[o]=((sum[o<<1]+sum[(o<<1)+1])%M+((r-l+1)%M*add[o])%M)%M;
//maxnum[o]=max(maxnum[o<<1],maxnum[(o<<1)+1]);
return;
}
LL get_sum(int l,int r,int o,LL a){
if(qx<=l&&r<=qy){
return (sum[o]+a%M*(r-l+1)%M)%M;
}
int mid=((r-l)>>1)+l;
LL ans=0;
if(qx<=mid){
ans=(ans+get_sum(l,mid,o<<1,a+add[o]))%M;
}
if(qy>mid){
ans=(ans+get_sum(mid+1,r,(o<<1)+1,a+add[o]))%M;
}
return ans;
}
/*int get_max(int l,int r,int o){
if(qx<=l&&r<=qy){
return maxnum[o];
}
int mid=((r-l)>>1)+l,ans=0;
if(qx<=mid){
ans=max(ans,get_max(l,mid,o<<1));
}
if(mid<qy)
ans=max(ans,get_max(mid+1,r,(o<<1)+1));
return ans;
}*/
void init(){
DFS1(root);
DFS2(root,root);
build(1,n,1);
}
void UPDATE(int x,int y,int z){
int f1=top[x];
int f2=top[y];
while(f1!=f2){
if(dep[f1]>dep[f2]){
qx=id[f1],qy=id[x],qd=z;
update(1,n,1);
x=fa[f1];
}
else{
qx=id[f2],qy=id[y],qd=z;
update(1,n,1);
y=fa[f2];
}
f1=top[x];
f2=top[y];
}
if(dep[x]>dep[y])swap(x,y);
qx=id[x],qy=id[y],qd=z;
update(1,n,1);
}
/*void MAX(int x,int y){
int f1=top[x];
int f2=top[y];
int ans=0;
while(f1!=f2){
if(dep[f1]>dep[f2]){
qx=id[f1],qy=id[x];
ans=max(ans,get_max(1,n,1,0));
x=fa[f1];
}
else{
qx=id[f2],qy=id[y];
ans=max(ans,get_max(1,n,1,0));
y=fa[f2];
}
f1=top[x];
f2=top[y];
}
if(dep[x]>dep[y])swap(x,y);
qx=id[x],qy=id[y];
ans=max(ans,get_max(1,n,1,0));
printf("%d\n",ans);
}*/
void SUM(int x,int y){
int f1=top[x];
int f2=top[y];
int ans=0;
while(f1!=f2){
if(dep[f1]>dep[f2]){
qx=id[f1],qy=id[x];
ans=(ans+get_sum(1,n,1,0))%M;
x=fa[f1];
}
else{
qx=id[f2],qy=id[y];
ans=(ans+get_sum(1,n,1,0))%M;
y=fa[f2];
}
f1=top[x];
f2=top[y];
}
if(dep[x]>dep[y])swap(x,y);
qx=id[x],qy=id[y];
ans=(ans+get_sum(1,n,1,0))%M;
printf("%lld\n",ans%M);
}
void SUBTREE_UPDATE(int x,int z){
qx=id[x],qy=id[x]+size[x]-1,qd=z;
update(1,n,1);
}
void SUBTREE_SUM(int x){
qx=id[x],qy=id[x]+size[x]-1;
printf("%lld\n",get_sum(1,n,1,0)%M);
}
int main(){
/*freopen("input.txt","r",stdin);
freopen("output.txt","w",stdout);*/
scanf("%d%d%d%d",&n,&m,&root,&M);
for(int i=1;i<=n;i++)
scanf("%d",&line[i]);
LL x,y,z,op;
for(int i=0;i<n-1;i++){
scanf("%d%d",&x,&y);
geo[y].push_back(x);
geo[x].push_back(y);
}
init();
for(int i=0;i<m;i++){
scanf("%d",&op);
if(op==1){
scanf("%d%d%lld",&x,&y,&z);
UPDATE(x,y,z);
}
else if(op==2){
scanf("%d%d",&x,&y);
SUM(x,y);
}
else if(op==3){
scanf("%d%lld",&x,&y);
SUBTREE_UPDATE(x,y);
}
else{
scanf("%d",&x);
SUBTREE_SUM(x);
}
}
return 0;
}
然而本人代码还是非常地丑QAQ
细节
需要注意的就是线段数的那个部分以及查询/修改刚开始的部分,每次判断 top 的深度,然后不断地向上移动
就是DFS什么的别写错就行啦
总结
树链剖分是个好东西,我们可以用它来搞许多事情