刚学了树链剖分,做了几题熟悉了一下,在此总结一下树剖的步骤吧
准备
deep[]:记录该节点的深度
siz[]:记录以该节点为根的子树的节点数
fa[]:记录该节点的父亲是谁
son[]:记录该节点的重儿子是谁
top[]:记录该节点所在的重链的根节点是谁
w[]:记录该节点投影到数轴后的位置(即DFS序,线段树要用)
具体步骤
- 读入,用空间池储存边(注意是双向的);
- 第一次DFS:
记录dep[],siz[],fa[],son[](找siz最大的儿子做重儿子) 第二次DFS:
记录w[];
顺着son[]处理重链记录top[](包括自己)
之后处理轻链
【伪代码】
设现处理x节点,其重父亲为tp
dfs2(x,tp)
time++;
w[x]=time;
top[x]=tp;
if(有重儿子){
dfs2(重儿子,tp);
枚举其他儿子v->dfs2(v,v);
}建线段树
- 修改:从x到y的节点加d
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
insert(1,1,n,w[top[x]],w[x],d);
x=fa[top[x]];
}
if(dep[x] > dep[y]) swap(x,y);
insert(1,1,n,w[x],w[y],d);
6.查询(和线段树一样就不说了)
贴个代码吧,hdu3966,比较经典的树剖题
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <iostream>
#include <string>
#include <algorithm>
using namespace std;
#define zero(a) memset(a,0,sizeof(a))
#define minus(a) memset(a,-1,sizeof(a))
const int MAX_N = 50100;
struct edge{
int idx;
int next;
}e[MAX_N * 3];
int siz[MAX_N];
int dep[MAX_N];
int fa[MAX_N];
int son[MAX_N];
int top[MAX_N];
int w[MAX_N];
int a[MAX_N];
int h[MAX_N];
int n,m,q,ep,time;
int cnt[MAX_N * 4];
inline void add(int x,int y){
ep++;
e[ep].idx = y;
e[ep].next = h[x];
h[x] = ep;
return;
}
void dfs1(int x,int fat,int deep){
dep[x] = deep;
siz[x] = 1;
fa[x] = fat;
for(int i=h[x];i!=-1;i=e[i].next){
int idx = e[i].idx;
if(idx == fat)
continue;
dfs1(idx,x,deep+1);
siz[x]+=siz[idx];
if(son[x]==-1 || siz[idx]>siz[son[x]])
son[x]=idx;
}
return;
}
void dfs2(int x,int tp){
time++;
w[x]=time;
top[x] = tp;
if(son[x]!=-1){
dfs2(son[x],tp); //ÖØÁ´
for(int i=h[x];i!=-1;i=e[i].next){
int v = e[i].idx;
if(v!=son[x] && v!=fa[x]) dfs2(v,v); //ÇáÁ´
}
}
return;
}
void insert(int rt,int l,int r,int a,int b,int val){
if(b<l||a>r)
return;
if(a<=l&&b>=r){
cnt[rt]+=val;
return;
}
int lson = rt*2;
int rson = rt*2+1;
int mid = (l+r)/2;
insert(lson,l,mid,a,b,val);
insert(rson,mid+1,r,a,b,val);
return;
}
inline void revise(int x,int y,int d){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
insert(1,1,n,w[top[x]],w[x],d);
x=fa[top[x]];
}
if(dep[x] > dep[y]) swap(x,y);
insert(1,1,n,w[x],w[y],d);
return;
}
int query(int rt,int l,int r,int idx,int sum){
if(l==r)
return sum+cnt[rt];
int lson = rt*2;
int rson = rt*2+1;
int mid = (l+r)/2;
if(mid>=idx)
return query(lson,l,mid,idx,sum+cnt[rt]);
else
return query(rson,mid+1,r,idx,sum+cnt[rt]);
}
inline void ask(int x){
//for(int i=1;i<=n;i++)
//cout<<i<<" "<<cnt[i]<<endl;
int add = query(1,1,n,w[x],0);
printf("%d\n",a[x]+add);
return;
}
inline void init(){
zero(siz);
zero(dep);
zero(fa);
zero(top);
zero(w);
zero(e);
zero(cnt);
zero(a);
minus(son);
minus(h);
ep=0;
time=0;
return;
}
inline void read(){
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<=m;i++){
int x,y;
scanf("%d %d",&x,&y);
add(x,y);
add(y,x);
}
return;
}
inline void build(){
dfs1(1,1,1);
dfs2(1,1);
return;
}
inline void solve(){
char c[5];
int x,y,d;
for(int i=1;i<=q;i++){
scanf("%s",&c);
if(c[0]=='Q'){
scanf("%d",&x);
ask(x);
continue;
}
scanf("%d %d %d",&x,&y,&d);
if(c[0]=='D')
d=-d;
revise(x,y,d);
}
return;
}
int main(){
while(scanf("%d %d %d",&n,&m,&q)!=EOF){
init();
read();
build();
solve();
}
return 0;
}