重链剖分
1.(图中白色的)重子节点:在其父节点中的所有的子节点子树最大的(就是底下节点最多的)
2.(图中橙色的)轻子节点:不是在其父节点中的所有的子节点子树最大的。(一条链的开始)
关于树上区间修改(或者查询)
通过dfs序建一个线段树。
比如修改图中节点6-节点4。
就是做O(log n)次修改,修改节点6-节点4经过的所有链上的点(点要经过两点间的最短路径
)
Top数组
需要一个top数组记录一条链的轻节点。
关于DFS1
(求dep[u]深度,根深度0、fa[u]:节点u的父亲、hson[u]节点u的重儿子,如果u是叶子节点,则值为-1,sz[u]为根节点为子树的节点数量)
初始代码
#include<bits/stdc++.h>
using namespace std;
const int N = 200000+10;
vector<int> G[N] ;
int dep[N],fa[N],hson[N],sz[N],top[N],dfn[N];
dfs1代码
void dfs1(int u,int p){//u当前访问到的,p父节点。
fa[u] = p;
sz[u] = 1;
dep[u] = dep[p]+1;
hson[u] = -1;
for(int i = 0;i < G[u].size();i++){
int v = G[u][i];
if(v == p)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(hson[u] == -1|| sz[v] > sz[hson[u]]){
hson[u] = v;
}
}
}
DFS2
求top[u] ,dfn[u] (top数组记录u所在链上的轻节点。dfn[u]即u的dfs序)
注意:这个dfs要优先遍历重儿子
DFS2代码
int cnt;//记录是第几次访问
void dfs2(int u,int p){
if(u == 1)top[1] = 1;//一定要加!!!
//先遍历重儿子
cnt++;
dfn[u] = cnt;
if(hson[u] == -1)return;
top[hson[u]] = top[u];
dfs2(hson[u],u);
for(int i = 0;i < G[u].size();i++){
int v = G[u][i];
if(v == p || hson[u] == v)continue;
top[v] = v;
dfs2(v,u);
}
}
注意事项
一定要赋top[1]初始值 ,top[1] = 1;!!!
写LCA(用top数组求最近公共祖先)
如果两个节点top值一样就在同一条链上,那么公共祖先就是dep小的那一个。
如果不一样,那就top值小的往上走,走到top的父亲上
lca代码
int lca(int u,int v){
if(top[u] == top[v])return dep[u]<dep[v] ? u : v;//那个深度小就取哪个。
if(dep[top[u]]<dep[top[v]]) return lca(u,fa[top[v]]);
else return lca(fa[top[u]],v);
//因为他是要跳到top的父亲上所以要比较两个top的深度;
}
小试牛刀
用树链剖分求最近公共祖先
1000. 最近公共祖先
限制条件
时间限制: 1000 ms, 空间限制: 256 MB
题目描述
给定一棵 n个节点的树,根节点为1 。
再给出 q次询问:回答树上任意两点的最近公共祖先是哪个节点。
输入格式
第一行n有一个整数 ,表示节点个数。
接下来n-1 行,每行有两个整数u , v,表示树上的一条边。
接下来一行有一个整数q ,表示询问次数。
接下来 q行,每行有两个整数 u, v,表示询问的节点编号。
输出格式
对于每次询问输出一行,一个整数,表示最近公共祖先的节点编号。
样例
样例输入 复制
5
1 2
1 3
3 4
3 5
5
2 1
4 2
4 5
3 2
1 5
样例输出 复制
1
1
3
1
1
AC代码
#include<bits/stdc++.h>
using namespace std;
const int N = 200000+10;
vector<int> G[N] ;
int dep[N],fa[N],hson[N],sz[N],top[N],dfn[N];
void dfs1(int u,int p){//u当前访问到的,p父节点。
fa[u] = p;
sz[u] = 1;
dep[u] = dep[p]+1;
hson[u] = -1;
for(int i = 0;i < G[u].size();i++){
int v = G[u][i];
if(v == p)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(hson[u] == -1|| sz[v] > sz[hson[u]]){
hson[u] = v;
}
}
}
int cnt;//记录是第几次访问
void dfs2(int u,int p){
if(u == 1)top[1] = 1;
//先遍历重儿子
cnt++;
dfn[u] = cnt;
if(hson[u] == -1)return;
top[hson[u]] = top[u];
dfs2(hson[u],u);
for(int i = 0;i < G[u].size();i++){
int v = G[u][i];
if(v == p || hson[u] == v)continue;
top[v] = v;
dfs2(v,u);
}
}
int lca(int u,int v){
if(top[u] == top[v])return dep[u]<dep[v] ? u : v;//那个深度小就取哪个。
if(dep[top[u]]<dep[top[v]]) return lca(u,fa[top[v]]);
else return lca(fa[top[u]],v);
//因为他是要跳到top的父亲上所以要比较两个top的深度;
}
int main(){
int n;
cin>>n;
for(int i = 1;i <= n-1;i++){
int u,v;
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1,1);
dfs2(1,1);
int q;
cin>>q;
while(q--){
int u,v;
cin>>u>>v;
cout<<lca(u,v)<<endl;
}
return 0;
}
进阶例题
建树
按照dfs序建树。
res[i]表示dfs序为i的节点编号。
void build(int index,int begin,int end){
if(begin == end){
sgt[index].sum = a[res[begin]];
return ;
}
int mid = (begin+end)/2;
build(index*2,begin,mid);
build(index*2+1,mid+1,end);
sgt[index].sum = sgt[index*2].sum+sgt[index*2+1].sum;
}
只有sgt[index].sum = a[res[begin]];这个要在原线段树代码上改动。
操作1的实现
代码
if(op == 1){
int id, y;
cin>>id>>y;
update(1,1,n,dfn[id],dfn[id],y);
}
核心思路:将id转化为dfs序
操作2的实现
代码
if(op == 2){
int id, y;
cin>>id>>y;
update(1,1,n,dfn[id],dfn[id]+sz[id]-1,y);
}
核心思路:子树的dfs序一定是连续的。所以就是修改dfs[id]到dfs[id]+sz[id]-1这一个区间。
注意:要把+sz[id]-1拿出来,不是dfs[id+sz[id]-1]!
操作3的实现
代码
long long getsum(int x){
long long sum = 0;
while(1){
sum+=query(1,1,n,dfn[top[x]],dfn[x]);//注意是top[x]到x因为top的dfs序更小
if(top[x] == 1)break;
x = fa[top[x]];
}
return sum;
}
if(op == 3){
int id;
cin>>id;
cout<<getsum(id)<<endl;
}
核心思路:利用top不断追溯到根节点。期间不断求区间(dfn[top[x]],dfn[x])的和。然后将x变为top[x]的父亲,一直到top[x]为根节点。
AC代码
#include<bits/stdc++.h>
using namespace std;
int n,m;
const int N = 200000+10;
vector<int> G[N] ;
int dep[N],fa[N],hson[N],sz[N],top[N],dfn[N],res[N];
struct Node{
long long sum,add;
};
Node sgt[N*4];
long long a[N];
void build(int index,int begin,int end){
if(begin == end){
sgt[index].sum = a[res[begin]];
return ;
}
int mid = (begin+end)/2;
build(index*2,begin,mid);
build(index*2+1,mid+1,end);
sgt[index].sum = sgt[index*2].sum+sgt[index*2+1].sum;
}
void push_up(int index){
sgt[index].sum = sgt[index*2].sum + sgt[index*2+1].sum;
}
void push_down(int index,int begin,int end){
sgt[index*2].add += sgt[index].add;
sgt[index*2+1].add += sgt[index].add;
sgt[index].sum += sgt[index].add * (end - begin +1);
sgt[index].add = 0;
}
void update(int index,int begin,int end,int left,int right,int x){
if(left == begin&& right == end){
sgt[index].add += x;
return ;
}
sgt[index].sum += 1LL*x*(right -left + 1);
push_down(index,begin,end);
int mid = (begin+end)/2;
if(right <= mid) return update(index*2,begin,mid,left,right,x);
else if(left > mid)return update(index*2+1,mid+1,end,left,right,x);
else{
update(index*2,begin,mid,left,mid,x);
update(index*2+1,mid+1,end,mid+1,right,x);
}
}
long long query(int index,int begin,int end,int left,int right){
if(left == begin && right == end){
return sgt[index].sum + sgt[index].add*(end -begin+1);
}
push_down(index,begin,end);
int mid = (begin+end)/2;
if(right <= mid) return query(index*2,begin,mid,left,right);
else if(left > mid)return query(index*2+1,mid+1,end,left,right);
else{
return query(index*2,begin,mid,left,mid)+query(index*2+1,mid+1,end,mid+1,right);
}
}
//____________________________________分割线。
void dfs1(int u,int p){//u当前访问到的,p父节点。
fa[u] = p;
sz[u] = 1;
dep[u] = dep[p]+1;
hson[u] = -1;
for(int i = 0;i < G[u].size();i++){
int v = G[u][i];
if(v == p)continue;
dfs1(v,u);
sz[u]+=sz[v];
if(hson[u] == -1|| sz[v] > sz[hson[u]]){
hson[u] = v;
}
}
}
int cnt;//记录是第几次访问
void dfs2(int u,int p){
if(u == 1)top[1] = 1;
//先遍历重儿子
cnt++;
dfn[u] = cnt;
res[cnt] = u;
if(hson[u] == -1)return;
top[hson[u]] = top[u];
dfs2(hson[u],u);
for(int i = 0;i < G[u].size();i++){
int v = G[u][i];
if(v == p || hson[u] == v)continue;
top[v] = v;
dfs2(v,u);
}
}
long long getsum(int x){
long long sum = 0;
while(1){
sum+=query(1,1,n,dfn[top[x]],dfn[x]);//注意是top[x]到x因为top的dfs序更小
if(top[x] == 1)break;
x = fa[top[x]];
}
return sum;
}
int main(){
cin>>n>>m;
for(int i = 1;i <= n;i++){
cin>>a[i];
}
for(int i = 1;i <= n-1;i++){
int u,v;
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1,1);
dfs2(1,1);
build(1,1,n);
while(m--){
int op;
cin>>op;
if(op == 1){
int id, y;
cin>>id>>y;
update(1,1,n,dfn[id],dfn[id],y);
}
if(op == 2){
int id, y;
cin>>id>>y;
update(1,1,n,dfn[id],dfn[id]+sz[id]-1,y);
}
if(op == 3){
int id;
cin>>id;
cout<<getsum(id)<<endl;
}
}
return 0;
}
总结
树链剖分就是利用将树转换成一条条链来处理树上问题