笔记
树链剖分:顾名思义,就是对树的链不断地剖,不断地分。
链的分类:
1、重链剖分
2、长链剖分
- 重链剖分:每次从最重的儿子剖下去
以sz[u]表示u的结点个数 - 长链剖分:每次找最深的儿子剖下去
为啥重链剖分比长链剖分好???
- 如果对于一棵树,每个节点的儿子都很深
- 长链剖分就可能会经过 n \sqrt n n条链
- 直接崩掉
证明:重链剖分 O ( l o g n ) O(logn) O(logn)
- 对于每个点,假设只有一个重儿子
- s z [ u ] > = s z [ h e a v y u ] + s z [ v ] + 1 sz[u]>=sz[heavy_u]+sz[v]+1 sz[u]>=sz[heavyu]+sz[v]+1
- 然后不会了。。
- //抱歉,这里只是笔记,当听到这里的时候后面就没跟上。。将就着吧。
实现方式:
1、每条链建立一棵线段树
2、建一棵线段树
- 先做一遍dfs,预处理出每个点的字子树个数以及重儿子编号等等
- 每次dfs,优先访问重儿子
- 好处???
仍然是dfs序,重链都连在了一起,且是连续的一段 - 然后就可以用top[x]记录x所属的重链的头号元素
- 然后就可以用dfs序处理一系列问题
- 具体实现请看程序。
听说树链剖分一般用来实现树上路径问题???
例题讲解
题目传送门
题目描述:
一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。
我们将以下面的形式来要求你对这棵树完成一些操作:
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
注意:从点u到点v的路径上的节点包括u和v本身
Solution
首先,这道题是求树上路径问题,我们这时就应该想到两个算法:点分治和树链剖分
因为,,这两个算法似乎都是用来求树上路径问题的???也许吧。。。
不过,这题是在线的,点分治似乎并不是那么好实现,所以我们想树链剖分。
这题其实就是树链剖分的模板题。
先将树求一遍dfs序,问题转化为区间求和、最大值问题。
考虑线段树。
dfs序、预处理等完善后,就可以进行树链剖分。
过程其实类似于倍增求LCA的过程
对于一段区间求和,将区间对应到树上,每次找两个点所对应的重链的顶端,不断对顶端到那个点的区间求和,一直往上递归即可,知道两点处于同一条链中,这时就可以求两点的区间。
用线段树很好实现。
Code
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=60000;
int n,m;
int cnt=0,len=0;
struct node{
int y,Next;
}e[N*2];
struct Node{
int l,r,maxx,s;
};
int a[N*2];
int linkk[N*2];
int dfn[N*2];
int top[N*2];
int fa[N*2];
int pos[N*2];
int hv[N*2];
int sz[N*2];
int d[N*2];
void insert(int x,int y){
e[++len]=(node){y,linkk[x]};
linkk[x] = len;
}
void dfshv(int x,int faa,int dd){
d[x] = dd;
sz[x] = 1;
fa[x] = faa;
for (int i=linkk[x];i;i=e[i].Next){
int y = e[i].y;
if (y == faa) continue;
dfshv(y,x,dd+1);
sz[x] += sz[y];
if (sz[hv[x]] < sz[y]) hv[x] = y;
}
}
void dfs(int x,int rt){
top[x] = rt;
dfn[x] = ++cnt;
pos[cnt] = x;
if (hv[x] > 0) dfs(hv[x],rt);
for (int i=linkk[x];i;i=e[i].Next){
int y=e[i].y;
if (y == fa[x] || y == hv[x]) continue;
dfs(y,y);
}
}
struct Tr{
Node tr[N*2+100];
void build(int p,int l,int r){
tr[p].l=l,tr[p].r=r;
tr[p].s=0,tr[p].maxx=-10000000;
if (l == r) {tr[p].s = tr[p].maxx = a[pos[l]];return;}
int mid=l+r>>1;
build(p<<1,l,mid),build(p<<1|1,mid+1,r);
tr[p].maxx = max(tr[p<<1].maxx , tr[p<<1|1].maxx);
tr[p].s = tr[p<<1].s + tr[p<<1|1].s;
}
void change(int p,int x,int v){
int L=tr[p].l,R=tr[p].r;
if (L == R && L == x) {tr[p].maxx = tr[p].s = v;return;}
int mid = (L+R)>>1;
if (x<=mid) change(p<<1,x,v);
else change(p<<1|1,x,v);
tr[p].maxx = max(tr[p<<1].maxx , tr[p<<1|1].maxx);
tr[p].s = tr[p<<1].s + tr[p<<1|1].s;
}
int askm(int p,int l,int r){
int L=tr[p].l,R=tr[p].r;
if (L>r || R<l) return -10000000;
if (l<=L && R<=r) return tr[p].maxx;
int mid=L+R>>1;
int Maxx=-10000000;
if (l<=mid) Maxx=max(Maxx,askm(p<<1,l,r));
if (mid < r) Maxx=max(Maxx,askm(p<<1|1,l,r));
return Maxx;
}
int asks(int p,int l,int r){
int L=tr[p].l,R=tr[p].r;
if (L>r || R<l) return 0;
if (l<=L && R<=r) return tr[p].s;
int mid=L+R>>1;
int ss=0;
if (l<=mid) ss+=asks(p<<1,l,r);
if (mid < r) ss+=asks(p<<1|1,l,r);
return ss;
}
};
Tr tr;
int qmax(int u,int v){
int ans=-100000000;
for (;top[u] != top[v];u=fa[top[u]]){
if (d[top[u]] < d[top[v]]) swap(u,v);
ans = max(ans , tr.askm(1,dfn[top[u]],dfn[u]));
}
if (d[u] > d[v]) swap(u,v);
ans = max(ans,tr.askm(1,dfn[u],dfn[v]));
return ans;
}
int qsum(int u,int v){
int ans=0;
for (;top[u] != top[v] ; u=fa[top[u]]){
if (d[top[u]] < d[top[v]]) swap(u,v);
ans += tr.asks(1,dfn[top[u]],dfn[u]);
}
if (d[u] > d[v]) swap(u,v);
ans += tr.asks(1,dfn[u],dfn[v]);
return ans;
}
signed main(){
scanf("%lld",&n);
for (int i=1,x,y;i<n;i++)
scanf("%lld %lld",&x,&y),insert(x,y),insert(y,x);
for (int i=1;i<=n;i++) scanf("%lld",&a[i]);
dfshv(1,0,1);
dfs(1,1);
tr.build(1,1,n);
scanf("%lld",&m);
for (int i=1;i<=m;i++){
char c[10];
int x,y;
scanf("%s %lld %lld",c,&x,&y);
if (c[1] == 'H') tr.change(1,dfn[x],y);
if (c[1] == 'M') printf("%lld\n",qmax(x,y));
if (c[1] == 'S') printf("%lld\n",qsum(x,y));
}
return 0;
}