树链剖分
传送门:洛谷P3384
预备知识
1.dfs序
2.链式前向星 ,这个比stl的vector要快很多,而且代码也不是很长,不会建议学一下
3.线段树/树状数组
树链剖分
树链剖分,指一种对树进行划分的算法 (你别告诉我你不知道什么是树) ,它先通过轻重边剖分将树分为多条链,保证每个点属于且只属于一条链,然后再通过数据结构(树状数组、BST、SPLAY、线段树等)来维护每一条链。~~
1.可能涉及到的名词解释
1.重儿子:非子叶结点的儿子中儿子最多的结点为重儿子(有点绕)
2.轻儿子:非子叶结点的儿子中除掉重儿子之外的儿子结点
3.重边:连接两个重儿子结点的边
4.轻边:除了重边之外的边
4.重链:一条链上全是重边连起来的链
性质:每一条重链都是从轻儿子开始的(根结点是轻儿子)
2.实现
1.求出所有结点的深度,父节点,子树大小以及重儿子(子叶结点除外)实现代码如下:
变量解释:
depth[x]:为结点深度
ather[x]:x结点的父节点
size[x]:x结点为根节点组成的子树的大小
maxson:重儿子的儿子数
son[x]:x结点的重儿子
inline void dfs1(int nowp,int fa){
depth[nowp]=depth[fa]+1;
father[nowp]=fa;
size[nowp]=1;
int maxson=-1;
for(int i=head[nowp];i;i=edge[i].next){
int to=edge[i].to;
if(to==fa)continue;
dfs1(to,nowp);
size[nowp]+=size[to];
if(size[to]>maxson)son[nowp]=to,maxson=size[to];
}
}
2.对所有的结点重新编号,求出所有结点的链顶端结点
变量解释
id[x]:x结点的新编号
top[x]:x结点的链顶端结点
value[x]:x结点的值
new_value[x]:x(x为新节点)结点的值
inline void dfs2(int nowp,int topf){
id[nowp]=++cnt;
new_value[cnt]=value[nowp];
top[nowp]=topf;
if(!son[nowp])return;
dfs2(son[nowp],topf);
for(int i=head[nowp];i;i=edge[i].next){
int to=edge[i].to;
if(to==father[nowp]||to==son[nowp])continue;
dfs2(to,to);
}
}
到这里,树已经被重新分割成了若干条独立的轻重链,并且每一条链的序号都是连续的,这一点也很好理解,递归嘛
3.现在我们来处理题目给出的4重操作
1.处理任意两点间路径上的点权和
2.处理一点及其子树的点权和
3.修改任意两点间路径上的点权
4.修改一点及其子树的点权
实现代码如下:
inline int sum_range(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(depth[top[x]]<depth[top[y]])
swap(x,y);
ans+=sum(id[x],id[top[x]]);
ans%=mod;
x=father[top[x]];
}
if(depth[x]<depth[y])swap(x,y);
ans+=sum(id[x],id[y]);
return ans%mod;
}
inline int sum_son(int x){
return sum(id[x],id[x]+size[x]-1);
}
inline int update_range(int x,int y,int k){
k%=mod;
while(top[x]!=top[y]){
if(depth[top[x]]<depth[top[y]])swap(x,y);
add2(id[top[x]],id[x],k);
x=father[top[x]];
}
if(depth[x]<depth[y])swap(x,y);
add2(id[y],id[x],k);
}
inline void update_son(int x,int k){
k%=mod;
add2(id[x],id[x]+size[x]-1,k);
}
最后在加上线段树板子就大功告成啦,下面附上ac代码(我用的是树状数组)
#include<stdio.h>
#include<iostream>
using namespace std;
typedef long long ll;
const int MAX_N=5e5+10;
int bit0[MAX_N],bit1[MAX_N];
int head[MAX_N],value[MAX_N],new_value[MAX_N];
int cnt=0,ccnt=1;
int depth[MAX_N],father[MAX_N],size[MAX_N],son[MAX_N],top[MAX_N],id[MAX_N];
int n,mod;
struct Edge{
int to;
int next;
}edge[2*MAX_N];
inline void add_edge(int from, int to){
edge[ccnt].to = to;
edge[ccnt].next = head[from];
head[from] = ccnt++;
}
inline ll fsum(int *b,int i){
ll s=0;
while(i>0){
s+=b[i];
i-=i&-i;
}
return s%mod;
}
inline void add(int *b,int i,int x){
x%=mod;
while(i<=n){
b[i]=(x+b[i])%mod;
i+=i&-i;
}
}
inline void add2(int l,int r,int x){
x%=mod;
add(bit0,l,(-x*(l-1)%mod+mod)%mod);
add(bit1,l,x);
add(bit0,r+1,x*r%mod);
add(bit1,r+1,-x);
}
inline ll sum(int l,int r){
if(l>r)swap(l,r);
ll res=0;
res+=fsum(bit0,r)+fsum(bit1,r)*r;
res-=fsum(bit0,l-1)+fsum(bit1,l-1)*(l-1);
return (res%mod+mod)%mod;
}
inline void dfs1(int nowp,int fa){
depth[nowp]=depth[fa]+1;
father[nowp]=fa;
size[nowp]=1;
int maxson=-1;
for(int i=head[nowp];i;i=edge[i].next){
int to=edge[i].to;
if(to==fa)continue;
dfs1(to,nowp);
size[nowp]+=size[to];
if(size[to]>maxson)son[nowp]=to,maxson=size[to];
}
}
inline void dfs2(int nowp,int topf){
id[nowp]=++cnt;
new_value[cnt]=value[nowp];
top[nowp]=topf;
if(!son[nowp])return;
dfs2(son[nowp],topf);
for(int i=head[nowp];i;i=edge[i].next){
int to=edge[i].to;
if(to==father[nowp]||to==son[nowp])continue;
dfs2(to,to);
}
}
inline int sum_range(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(depth[top[x]]<depth[top[y]])
swap(x,y);
ans+=sum(id[x],id[top[x]]);
ans%=mod;
x=father[top[x]];
}
if(depth[x]<depth[y])swap(x,y);
ans+=sum(id[x],id[y]);
return ans%mod;
}
inline int sum_son(int x){
return sum(id[x],id[x]+size[x]-1);
}
inline int update_range(int x,int y,int k){
k%=mod;
while(top[x]!=top[y]){
if(depth[top[x]]<depth[top[y]])swap(x,y);
add2(id[top[x]],id[x],k);
x=father[top[x]];
}
if(depth[x]<depth[y])swap(x,y);
add2(id[y],id[x],k);
}
inline void update_son(int x,int k){
k%=mod;
add2(id[x],id[x]+size[x]-1,k);
}
int main(){
int m,root;
scanf("%d%d%d%d",&n,&m,&root,&mod);
for(int i=1;i<=n;++i)cin>>value[i];
for(int i=1;i<=n-1;++i){
int f,t;
scanf("%d%d",&f,&t);
add_edge(f,t);
add_edge(t,f);
}
dfs1(root,0);
dfs2(root,root);
for(int i=1;i<=n;++i)add(bit0,i,new_value[i]);
for(int i=1;i<=m;++i){
int ask;cin>>ask;
if(ask==1){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
update_range(x,y,z);
}
else if(ask==2){
int x,y;
scanf("%d%d",&x,&y);
printf("%d\n",sum_range(x,y));
}
else if(ask==3){
int x,z;
scanf("%d%d",&x,&z);
update_son(x,z);
}
else if(ask==4){
int x;
scanf("%d",&x);
printf("%d\n",sum_son(x));
}
}
return 0;
}