什么是树链剖分
将一棵树分成几条不相交的链,使得这些链中所以节点的标号在区间内,这样我们就可以使用数据结构去维护这些链上的信息。
一些概念&定义
- deep[x]:节点x在树中的深度
- size[x]:以x为根的子树内的节点数
- 重儿子:一个节点的儿子节点中,size最大的那个。每个节点都只有一个重儿子。
- 重链:几个连续的重儿子组成的链
- top[x]:节点x所在重链的链顶(深度最小的节点)
- fa[x]:表示一个节点的父亲
如何去分链
一般我们都会使用轻重链剖分,即将几个连续的重儿子组成的链看做重链。
我们可以先用一个dfs将每个节点的size和deep处理出来。
在再次dfs整棵树,优先遍历重儿子,将整棵树重新标号,并处理处top。这时我们发现每条重链的节点都在一个连续的区间上,这样我们方便使用数据结构去维护重链上的信息。
以1所在的重链为例,1为链顶,边上的标号为1~4。
应用
使用树剖我们可以求出两点之间的lca,并维护路径上的最大值、和等。
每次我们将deep[top[x]]较大的节点跳到fa[top[x]],直到两点在同一条重链上。
例题
jzoj 2256、luogu P2590 [ZJOI2008]树的统计
code
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N=3e4+10;
int n,m,p[N],deep[N],last[N],size[N],fa[N],top[N],b[N],tr[N*4][2],v[N],z1,z2;
struct note
{
int a,b;
}a[N*2];
void add(int x,int y) {
a[++a[0].a].a=y;
a[a[0].a].b=last[x];
last[x]=a[0].a;
}
void dfs1(int x,int y) {
deep[x]=deep[y]+1;
fa[x]=y;
for (int i=last[x];i!=0;i=a[i].b) {
if (a[i].a!=y) {
dfs1(a[i].a,x);
size[x]+=size[a[i].a];
}
}
size[x]+=1;
}
void dfs2(int x,int y) {
b[x]=++b[0];
int z1=0,z2=0;
for (int i=last[x];i!=0;i=a[i].b) {
if (a[i].a!=y) {
if (z1<size[a[i].a] && a[i].a!=y) {
z1=size[a[i].a];
z2=a[i].a;
}
}
}
if (z2!=0) {
top[z2]=top[x];
dfs2(z2,x);
}
for (int i=last[x];i!=0;i=a[i].b) {
if (a[i].a!=y && a[i].a!=z2) {
top[a[i].a]=a[i].a;
dfs2(a[i].a,x);
}
}
}
void query(int x,int y,int t ,int l,int r) {
if (x==l && r==y) {
z1=max(z1,tr[t][0]);
z2+=tr[t][1];
}
else {
int mid=l+r>>1;
if (mid>=y) query(x,y,t*2,l,mid);
else {
if (mid<x) query(x,y,t*2+1,mid+1,r);
else {query(mid+1,y,t*2+1,mid+1,r);query(x,mid,t*2,l,mid);}
}
}
}
void change(int t,int l,int r,int x) {
if (l==r) {
tr[t][0]=tr[t][1]=v[l];
}
else {
int mid=l+r>>1;
if (mid>=x) change(t*2,l,mid,x); else change(t*2+1,mid+1,r,x);
tr[t][0]=max(tr[t*2][0],tr[t*2+1][0]);
tr[t][1]=tr[t*2][1]+tr[t*2+1][1];
}
}
void work(int x,int y) {
while (1) {
if (top[x]!=top[y]) {
if (deep[top[x]]<deep[top[y]]) {
query(b[top[y]],b[y],1,1,b[0]);
y=fa[top[y]];
}
else {
query(b[top[x]],b[x],1,1,b[0]);
x=fa[top[x]];
}
}
else {
if (deep[x]>deep[y]) swap(x,y);
query(b[x],b[y],1,1,b[0]);
break;
}
}
}
void build(int t,int l,int r) {
if (l==r) {
tr[t][0]=tr[t][1]=v[l];
}
else {
int mid=l+r>>1;
build(t*2,l,mid);build(t*2+1,mid+1,r);
tr[t][0]=max(tr[t*2][0],tr[t*2+1][0]);
tr[t][1]=tr[t*2][1]+tr[t*2+1][1];
}
}
int main() {
int i,j,k;
scanf("%d",&n);
for (i=1;i<=n-1;i++) {
int x,y;
scanf("%d%d",&x,&y);
add(x,y);add(y,x);
}
dfs1(1,0);
top[1]=1;
dfs2(1,0);
for (i=1;i<=n;i++) scanf("%d",&v[b[i]]);
build(1,1,b[0]);
scanf("%d",&m);
for (i=1;i<=m;i++) {
int x,y;
scanf("\n");
char c1,c=getchar();
if (c=='Q') {
c1=getchar();
if (c1=='M') scanf("AX");else scanf("UM");
scanf("%d%d",&x,&y);
z1=-2100000000;z2=0;
work(x,y);
if (c1=='M') printf("%d\n",z1);else printf("%d\n",z2);
}
else{
scanf("HANGE");scanf("%d%d",&x,&y);
v[b[x]]=y;
change(1,1,b[0],b[x]);
}
}
}