树链剖分
树链剖分,就是将一颗树划分成多条链,每次在操作的时候则是对链进行修改。
基本概念:
1、重儿子:这个节点的儿子中,子树最大的就是重儿子
2、轻儿子:除了重儿子之外的其它子节点
叶节点是没有重儿子和轻儿子的。
3、轻边:节点与轻儿子连接的边
4、重边:节点与重儿子连接的边
5、重链:均由重儿子连接而成的一条链
6、轻链:均由轻儿子连接而成的一条链
一、预处理信息
要运用树链剖分,就要预处理出所需要的信息。
dep:节点的深度 fa:节点的父亲
zson:节点的重儿子 num:节点的子树大小
这几个信息是很好处理的,用dfs即可 O(n) 处理。
void dfs1(int root,int last) { //求dep,fa,zson,num
dep[root]=dep[last]+1;
fa[root]=last;num[root]=1;
for (int t=first[root];t;t=edges[t].nxt) {
int h=edges[t].to;
if (h==last) continue;
dfs1(h,root);
num[root]+=num[h];
if (num[h]>num[zson[root]]) zson[root]=h;
}
}
接下来,就是要处理每条链的顶点st。也是用dfs进行处理。因为重链是固定的,所以我们在dfs的过程中是优先遍历重儿子,再遍历轻儿子。
void dfs2(int now,int topx) { //求st
st[now]=topx;
dfn[now]=++dnow;//求dfs序,对之后的线段树维护有用
if (!zson[now]) return;
dfs2(zson[now],topx);
for (int t=first[now];t;t=edges[t].nxt) {
int h=edges[t].to;
if (h==fa[now]) continue;
if (h==zson[now]) continue;
dfs2(h,h);
}
}
到此,需要的信息就已经处理好了。
二、利用树链剖分求LCA
基本思路:
(1) 如果两个节点在同一条链上,说明dep较小的节点就是LCA;
(2) 否则的话,比较两条链的顶点的深度,选择顶点dep较大的点向上条,重复操作,直到出现了(1)这种情况。
正确性证明:
首先,在我们对一颗树进行剖分的时候,每条链是没有重复的顶点和边的。如果两个节点在同一条链上,那么其中肯定有他们的LCA。链是由上到下的,每条链都是由父亲节点转化到子节点而来的。
如果两个节点不在一条链上,那么就需要一个点向上寻找LCA。为什么选择顶点dep较大的节点往上跳,是为了防止LCA错误。因为可以肯定,选择dep较小的点往上跳,上面节点所在的链必然不会包含dep较大的点,那么就会陷入死循环。当然,可以画图进行理解。
int LCA(int x,int y) {
int d1=x,d2=y;
while (st[d1]!=st[d2]) {
//printf("now:d1=%d,d2=%d\n",d1,d2);
if (dep[st[d1]]<dep[st[d2]]) swap(d1,d2);
//printf("%d->%d\n",d1,fa[st[d1]]);
d1=fa[st[d1]];
}
return dep[d1]<dep[d2]?d1:d2;
}
AC Code:
#include<bits/stdc++.h>
#include<string>
#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
using namespace std;
inline int read() {
int x=0,f=1;char s=getchar();
while (s>'9'||s<'0') {
if (s=='-') f=-f;
s=getchar();
}
while (s>='0'&&s<='9') {
x=(x<<1)+(x<<3)+s-'0';
s=getchar();
}
return x*f;
}
const int N = 1e5+10;
int n,u,v,s,m,first[N],cnt,dep[N],fa[N],zson[N],num[N],st[N];
struct edge{
int to,nxt;
}edges[N*2];
void add(int u,int v) {
edges[++cnt].to=v;
edges[cnt].nxt=first[u];
first[u]=cnt;
}
void dfs1(int root,int last) { //求dep,fa,zson,num
dep[root]=dep[last]+1;
fa[root]=last;num[root]=1;
for (int t=first[root];t;t=edges[t].nxt) {
int h=edges[t].to;
if (h==last) continue;
dfs1(h,root);
num[root]+=num[h];
if (num[h]>num[zson[root]]) zson[root]=h;
}
}
void dfs2(int now,int topx) { //求st
st[now]=topx;
if (!zson[now]) return;
dfs2(zson[now],topx);
for (int t=first[now];t;t=edges[t].nxt) {
int h=edges[t].to;
if (h==fa[now]) continue;
if (h==zson[now]) continue;
dfs2(h,h);
}
}
int LCA(int x,int y) {
int d1=x,d2=y;
while (st[d1]!=st[d2]) {
//printf("now:d1=%d,d2=%d\n",d1,d2);
if (dep[st[d1]]<dep[st[d2]]) swap(d1,d2);
//printf("%d->%d\n",d1,fa[st[d1]]);
d1=fa[st[d1]];
}
return dep[d1]<dep[d2]?d1:d2;
}
int main() {
n=read();m=read();s=read();
for (int i=1;i<n;++i) u=read(),v=read(),add(u,v),add(v,u);
dfs1(s,0);
//for (int i=1;i<=n;++i) printf("i=%d,dep=%d,fa=%d,zson=%d,num=%d\n",i,dep[i],fa[i],zson[i],num[i]);
dfs2(s,s);
//for (int i=1;i<=n;++i) printf("st[%d]=%d\n",i,st[i]);
for (int i=1;i<=m;++i) {
int x,y;
x=read();y=read();
printf("%d\n",LCA(x,y));
}
return 0;
}
这是利用树链剖分运行所需时间。
这是常规的倍增求LCA所花的时间。树链剖分在这道题上虽然不如tarjan算法快,但是还是十分可观的。
三、利用树链剖分解决树上修改查询问题
树链剖分的dfs2的遍历顺序十分地巧妙,因为每次是优先遍历重儿子,且每次遍历完这一整颗树才返回,导致他的dfs序有一个很奇妙的特点:每条链都是一个区间,每个子树也是一个区间。那么,子树和链上的修改查询就相当于在dfs序的一段区间内进行修改查询。那么我们便可以用线段树或者树状数组对他进行优化,这样时间复杂度就大大减小了。
这张图是从一位大佬那里搬过来的,因为本蒟蒻画图不行QAQ。如上图,可以清晰的看出每条链,每个子树都是一个区间这一特点。那么,我们便可以在dfs2的过程中,建立新节点nid,和新节点的点权nw,并保留每个节点对应的新节点,我们的链上修改和子树修改就十分的方便了
例题1
分析
这是一道模板题,要求是实现路径上的区间加,路径求和,子树上的区间加,子树求和。路径上的查询和修改,我们可以参照LCA的算法实现,将两个点之间的路径拆分成多条链或者一条链中的一部分,进行修改查询。而子树的查询和修改,可以根据他的性质,子树节点的dfs序是一个区间,且这个区间的范围是nid[root]~nid[root]+num[root]-1,其中num[root]是以root为根节点的子树的大小。这样,就可以很快的实现子树的查询与修改了。
Code
#include<bits/stdc++.h>
#include<string>
#include<vector>
#include<queue>
#include<stack>
#include<map>
#include<set>
#define int long long
using namespace std;
inline int read() {
int x=0,f=1;char s=getchar();
while (s>'9'||s<'0') {
if (s=='-') f=-f;
s=getchar();
}
while (s>='0'&&s<='9') {
x=(x<<1)+(x<<3)+s-'0';
s=getchar();
}
return x*f;
}
const int N = 1e5+10;
int n,m,s,mod,a[N],u,v,first[N],cnt,op,x,y,z;
int dep[N],fa[N],zson[N],num[N],st[N],nid[N],nw[N],dnow;
int sum[N*4],add[N*4];
struct edge{
int to,nxt;
}edges[N*2];
void Add(int u,int v) {
edges[++cnt].to=v;
edges[cnt].nxt=first[u];
first[u]=cnt;
}
void dfs1(int root,int last) {
fa[root]=last;dep[root]=dep[last]+1;
num[root]=1;
for (int t=first[root];t;t=edges[t].nxt) {
int h=edges[t].to;
if (h==last) continue;
dfs1(h,root);
num[root]+=num[h];
if (num[h]>num[zson[root]]) zson[root]=h;
}
}
void dfs2(int root,int topx) {
st[root]=topx;
nid[root]=++dnow;nw[dnow]=a[root];
if (!zson[root]) return;
else dfs2(zson[root],topx);
for (int t=first[root];t;t=edges[t].nxt) {
int h=edges[t].to;
if (h==zson[root]||h==fa[root]) continue;
dfs2(h,h);
}
}
void pushup(int k) { sum[k]=(sum[k<<1]+sum[k<<1|1])%mod;}
void build(int k,int l,int r) {
if (l==r) {
sum[k]=nw[l]%mod;
return;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
pushup(k);
}
void pushdown(int k,int l,int r) {
if (!add[k]) return;
int mid=(l+r)>>1;
sum[k<<1]=(sum[k<<1]+(mid-l+1)*add[k]%mod)%mod;
sum[k<<1|1]=(sum[k<<1|1]+(r-mid)*add[k]%mod)%mod;
add[k<<1]=(add[k<<1]+add[k])%mod;
add[k<<1|1]=(add[k<<1|1]+add[k])%mod;
add[k]=0;
}
void modify(int k,int l,int r,int zl,int zr,int v) {
if (zl<=l&&zr>=r) {
sum[k]=(sum[k]+(r-l+1)*v%mod)%mod;
add[k]=(add[k]+v)%mod;
return;
}
pushdown(k,l,r);
int mid=(l+r)>>1;
if (mid>=zl) modify(k<<1,l,mid,zl,zr,v);
if (mid<zr) modify(k<<1|1,mid+1,r,zl,zr,v);
pushup(k);
}
int ser(int k,int l,int r,int zl,int zr) {
if (zl<=l&&zr>=r) return sum[k];
pushdown(k,l,r);
int mid=(l+r)>>1,res=0;
if (mid>=zl) res=(res+ser(k<<1,l,mid,zl,zr))%mod;
if (mid<zr) res=(res+ser(k<<1|1,mid+1,r,zl,zr))%mod;
return res%mod;
}
void add_path(int x,int y,int z) {
int d1=x,d2=y;
while (st[d1]!=st[d2]) {
if (dep[st[d1]]<dep[st[d2]]) swap(d1,d2);
modify(1,1,n,nid[st[d1]],nid[d1],z);
d1=fa[st[d1]];
}
int l=min(nid[d1],nid[d2]),r=max(nid[d1],nid[d2]);
modify(1,1,n,l,r,z);
}
int ser_path(int x,int y) {
int ans=0,d1=x,d2=y;
while (st[d1]!=st[d2]) {
if (dep[st[d1]]<dep[st[d2]]) swap(d1,d2);
ans=(ans+ser(1,1,n,nid[st[d1]],nid[d1]))%mod;
d1=fa[st[d1]];
}
int l=min(nid[d1],nid[d2]),r=max(nid[d1],nid[d2]);
ans=(ans+ser(1,1,n,l,r))%mod;
return ans;
}
void add_tree(int root,int v) {
int l=nid[root],r=nid[root]+num[root]-1;
modify(1,1,n,l,r,v);
}
int ser_tree(int root) {
int l=nid[root],r=nid[root]+num[root]-1;
return ser(1,1,n,l,r);
}
signed main() {
n=read();m=read();s=read();mod=read();
for (int i=1;i<=n;++i) a[i]=read();
for (int i=1;i<n;++i) u=read(),v=read(),Add(u,v),Add(v,u);
dfs1(s,0);dfs2(s,s);
build(1,1,n);
while (m--) {
op=read();
if (op==1) {
x=read();y=read();z=read();
add_path(x,y,z);
}
if (op==2) {
x=read();y=read();
printf("%lld\n",ser_path(x,y)%mod);
}
if (op==3) {
x=read();z=read();
add_tree(x,z);
}
if (op==4) {
x=read();
printf("%lld\n",ser_tree(x)%mod);
}
}
return 0;
}