树链剖分,正如其字面意思
就是把树上的路径剖成一条条链
对于树上路径的区间修改查询类问题,直接dfs路径显然不能承受
而树链剖分的思想就是
将树上路径剖成一条条链
每条链都以一段连续区间标号以便在线段树上维护路径信息
首先是树链剖分的一些定义:
size[u] 子树大小;dep[u] 节点深度; fa[u] 父节点
这些没什么好解释的
son[u] :结点u的重儿子
重儿子:u的所有儿子中 size最大的那个儿子(若有相同则任选一个)
轻儿子:对于一个节点u,除了其重儿子其他全是轻儿子
重边:即连接重儿子与其父亲的边
重链:相邻重边连接形成的路径
top[u]: 结点u所在重链中深度最小的结点(或解释为其所在重链链顶)
对于每个结点u
若u是轻儿子,则
t
o
p
[
u
]
=
u
top[u]=u
top[u]=u
若u是重儿子,
t
o
p
[
u
]
=
t
o
p
[
f
a
[
u
]
]
top[u]=top[fa[u]]
top[u]=top[fa[u]]
特别的
t
o
p
[
r
t
]
=
r
t
top[rt]=rt
top[rt]=rt
以上可以由两次dfs预处理出
n
u
m
[
]
num[]
num[]记录树上结点对应线段树中的编号
p
o
s
[
]
pos[]
pos[]记录线段树上的编号对应原树上的哪个结点
void dfs1(int u,int pa)
{
size[u]=1;//初始化size,表示只有自己
for(int i=head[u];i;i=E[i].nxt)
{
int v=E[i].v;
if(v==pa) continue;
dep[v]=dep[u]+1; fa[v]=u;
dfs1(v,u);
size[u]+=size[v];//另u的size加上其儿子的size
if(size[v]>size[son[u]]) son[u]=v;//判断重儿子
}
}
void dfs2(int u,int tp)
{
num[u]=++cnt; pos[cnt]=u; //记录对应编号
top[u]=tp;//记录链顶
if(son[u]) dfs2(son[u],tp);//如果有重儿子则先处理重儿子
for(int i=head[u];i;i=E[i].nxt)
{
int v=E[i].v;//处理其他轻儿子
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
dep[rt]=1;
dfs1(rt,-1); dfs2(rt,rt)
树链剖分の应用
处理出上述信息后,来看看怎么将其转化为线段树上的区间操作吧
洛谷 P2590 [ZJOI2008]树的统计
I. CHANGE u t : 把结点u的权值改为t
II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值
III. QSUM u v: 询问从点u到点v的路径上的节点的权值和
题目分析:
对于CHANGE操作
直接在线段树上对
n
u
m
[
u
]
num[u]
num[u]进行单点修改
对于QSUM操作
1.比较top[u]和top[v]是否相同(即两者是否在同一条重链上)
2.若
t
o
p
[
u
]
!
=
t
o
p
[
v
]
top[u]!=top[v]
top[u]!=top[v],先假设dep[u]较大
不难发现,一条重链上的结点在线段树中的编号一定是连续的
所以对区间
[
l
l
=
n
u
m
[
t
o
p
[
u
]
]
,
r
r
=
n
u
m
[
u
]
]
[ll=num[ top[u] ],rr=num[u]]
[ll=num[top[u]],rr=num[u]]查询区间和,并以ans累加记录
然后另
u
=
f
a
[
t
o
p
[
u
]
]
u=fa[ top[u] ]
u=fa[top[u]],回到步骤1直到
t
o
p
[
u
]
=
=
t
o
p
[
v
]
top[u]==top[v]
top[u]==top[v]
有
t
o
p
[
u
]
=
=
t
o
p
[
v
]
top[u]==top[v]
top[u]==top[v]以后
假设dep[u]较大,再次查询区间
[
l
l
=
n
u
m
[
v
]
,
r
r
=
n
u
m
[
u
]
]
[ll=num[ v ],rr=num[u]]
[ll=num[v],rr=num[u]]的区间和累计如ans,最后返回ans
QMAX操作也是类似
#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return f*x;
}
void print(int x)
{
if(x<0){putchar('-');x=-x;}
if(x>9)print(x/10);
putchar(x%10+'0');
}
int n,t;
int tot;
struct node{int v,nxt;}E[100010];
int head[100010];
int w[100010];
int cnt;
int dep[100010],fa[100010];
int size[100010],son[100010];
int top[100010],num[100010],pre[100010];
int sum[400010],maxn[400010];
char ss[20];
void add(int u,int v)
{
E[++tot].v=v;
E[tot].nxt=head[u];
head[u]=tot;
}
void dfs1(int u,int pa)
{
size[u]=1;
for(int i=head[u];i;i=E[i].nxt)
{
int v=E[i].v;
if(v==pa) continue;
dep[v]=dep[u]+1; fa[v]=u;
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
}
void dfs2(int u,int tp)
{
num[u]=++cnt; pre[cnt]=u; top[u]=tp;
if(son[u]) dfs2(son[u],tp);
for(int i=head[u];i;i=E[i].nxt)
{
int v=E[i].v;
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
void push(int p)
{
maxn[p]=max(maxn[p<<1],maxn[p<<1|1]);
sum[p]=sum[p<<1]+sum[p<<1|1];
}
void build(int s,int t,int p)
{
if(s==t){ maxn[p]=sum[p]=w[ pre[s] ]; return; }
int mid=(s+t)>>1;
build(s,mid,p<<1); build(mid+1,t,p<<1|1);
push(p);
}
void update(int u,int w,int s,int t,int p)
{
if(s==t){maxn[p]=sum[p]=w;return;}
int mid=(s+t)>>1;
if(u<=mid) update(u,w,s,mid,p<<1);
else update(u,w,mid+1,t,p<<1|1);
push(p);
}
int getmax(int ll,int rr,int s,int t,int p)
{
if(ll<=s&&t<=rr) return maxn[p];
int mid=(s+t)>>1;
int ans=-1e9;
if(ll<=mid) ans=max(ans, getmax(ll,rr,s,mid,p<<1) );
if(rr>mid) ans=max(ans, getmax(ll,rr,mid+1,t,p<<1|1) );
return ans;
}
int qmax(int u,int v)
{
int ans=-1e9;
while (top[u]!=top[v])
{
if (dep[top[u]]<dep[top[v]])swap(u,v);
ans=max( ans,getmax(num[top[u]],num[u],1,n,1) );
u=fa[top[u]];
}
if (dep[u]<dep[v])swap(u,v);
ans=max(ans,getmax(num[v],num[u],1,n,1));
return ans;
}
int getsum(int ll,int rr,int s,int t,int p)
{
if(ll<=s&&t<=rr) return sum[p];
int mid=(s+t)>>1;
int ans=0;
if(ll<=mid) ans+=getsum(ll,rr,s,mid,p<<1) ;
if(rr>mid) ans+=getsum(ll,rr,mid+1,t,p<<1|1) ;
return ans;
}
int qsum(int u,int v)
{
int ans=0;
while(top[u]!=top[v])//若u和v不在一条重链上
{
if( dep[ top[u] ] < dep[ top[v] ]) swap(u,v);//取top深度较大的一方
ans+=getsum(num[top[u]],num[u],1,n,1);//更新u到其top
u=fa[top[u]];//另u跳到其top的父亲
}
if(dep[u]<dep[v]) swap(u,v);//在一条重链上,直接区间更新
ans+=getsum(num[v],num[u],1,n,1);
return ans;
}
int main()
{
n=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
add(u,v);add(v,u);
}
for(int i=1;i<=n;i++)
w[i]=read();
dep[1]=1; fa[1]=1;
dfs1(1,-1); dfs2(1,1);
build(1,n,1);//预处理——树剖
t=read();
while(t--)
{
scanf("%s",&ss);
int x=read(),y=read();
if(ss[1]=='H') update(num[x],y,1,n,1);//直接更新num[x]
else if(ss[1]=='M') print(qmax(x,y)),printf("\n");
else if(ss[1]=='S') print(qsum(x,y)),printf("\n");
}
return 0;
}
洛谷P3178 [HAOI2015]树上操作
操作 1 :把某个节点 x 的点权增加 a 。
操作 2 :把某个节点 x 为根的子树中所有点的点权都增加 a 。
操作 3 :询问某个节点 x 到根的路径中所有点的点权和。
题目分析:
这题与上面唯一不同的是要更新u的所有子树
不难发现其实u及其所有子节点在线段树上的编号
就是
[
l
l
=
n
u
m
[
u
]
,
r
r
=
n
u
m
[
u
]
+
s
i
z
e
[
u
]
−
1
]
[ll=num[u],rr=num[u]+size[u]-1]
[ll=num[u],rr=num[u]+size[u]−1]的连续区间
代码就不放了
洛谷P3038 [USACO11DEC]牧草种植Grass Planting
给出一棵n个节点的树,有m个操作
操作为将一条路径上的边权加一或询问某条边的权值
题目分析
树剖维护边权信息
对于每个节点u,他可以有很多孩子节点,但只能有一个父节点
显然我们可以用u的点权表示u->fa[u]这条道路的边权
这样就可以直接把边权的修改转换为点权的修改了
不过对于update(u到v的路径)我们要稍作修改
void update(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
update(num[top[u]],num[u],1,n,1);
u=fa[top[u]];
}
if(u==v) return;//**高亮**,若当前u==v则不修改
if(dep[u]>dep[v]) swap(u,v);
update(num[u]+1,num[v],1,n,1);//**高亮**,这里的左端点是num[u]+1
}
上面update到达最后一次更新的时候(最后一行)
当前的u必定是原来u和v的lca,而我们结点权值代表的是它到其父亲的边权
从u到v的最短路径必然不会经过他们lca的祖先,所以u和v的lca(也就是当前的u)权值不能一起更新
(有一点绕,可以把样例的图画出来模拟一下)
注意上面u==v的时候u也是他们的lca,所以要直接返回
#include<iostream>
#include<vector>
#include<algorithm>
#include<queue>
#include<cstring>
#include<cstdio>
using namespace std;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return f*x;
}
void print(int x)
{
if(x<0){putchar('-');x=-x;}
if(x>9)print(x/10);
putchar(x%10+'0');
}
const int maxn=500010;
int n,m;
int tot,cnt;
struct node{int v,nxt;}E[maxn];
int head[maxn];
int dep[maxn],fa[maxn],son[maxn];
int size[maxn],top[maxn],path[maxn];
int sum[maxn],add[maxn];
int num[maxn];
void adde(int u,int v)
{
E[++tot].nxt=head[u];
E[tot].v=v;
head[u]=tot;
}
void dfs1(int u,int pa)
{
size[u]=1;
for(int i=head[u];i;i=E[i].nxt)
{
int v=E[i].v;
if(v==pa) continue;
dep[v]=dep[u]+1; fa[v]=u;
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]]) son[u]=v;
}
}
void dfs2(int u,int tp)
{
top[u]=tp; num[u]=++cnt;
if(son[u]) dfs2(son[u],tp);
for(int i=head[u];i;i=E[i].nxt)
{
int v=E[i].v;
if(v==fa[u]||v==son[u])continue;
dfs2(v,v);
}
}
void push(int mid,int s,int t,int p)
{
add[p<<1]+=add[p]; add[p<<1|1]+=add[p];
sum[p<<1]+=add[p]*(mid-s+1);
sum[p<<1|1]+=add[p]*(t-mid);
add[p]=0;
}
void update(int ll,int rr,int s,int t,int p)
{
if(ll<=s&&t<=rr){sum[p]+=t-s+1;add[p]++;return;}
int mid=s+t>>1;
if(add[p]) push(mid,s,t,p);
if(ll<=mid) update(ll,rr,s,mid,p<<1);
if(rr>mid) update(ll,rr,mid+1,t,p<<1|1);
sum[p]=sum[p<<1]+sum[p<<1|1];
}
int get(int u,int s,int t,int p)
{
if(s==t)return sum[p];
int mid=s+t>>1;
if(add[p]) push(mid,s,t,p);
if(u<=mid) return get(u,s,mid,p<<1);
else return get(u,mid+1,t,p<<1|1);
}
void update(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]<dep[top[v]]) swap(u,v);
update(num[top[u]],num[u],1,n,1);
u=fa[top[u]];
}
if(u==v) return;//**高亮**,若当前u==v则不修改
if(dep[u]>dep[v]) swap(u,v);
update(num[u]+1,num[v],1,n,1);//**高亮**,这里的左端点是num[u]+1
}
int main()
{
n=read();m=read();
for(int i=1;i<n;i++)
{
int u=read(),v=read();
adde(u,v);adde(v,u);
}
dep[1]=1;
dfs1(1,-1);dfs2(1,1);
while(m--)
{
char ss; scanf("%s",&ss); int u=read(),v=read();
if(ss=='P') update(u,v);
else if(ss=='Q')
{
if(dep[u]<dep[v]) swap(u,v);
print(get(num[u],1,n,1));printf("\n");
//设深度较大的点为u,则fa[u]必定等于v,所以直接查询u的权值
}
}
return 0;
}