前置知识
树链剖分的思想及能解决的问题
树链剖分用于将树分割成若干条链的形式,以维护树上路径的信息。
具体来说,将整棵树剖分为若干条链,使它组合成线性结构,然后用其他的数据结构维护信息。
树链剖分(树剖/链剖)有多种形式,如 重链剖分,长链剖分 和用于 Link/cut Tree 的剖分(有时被称作“实链剖分”),大多数情况下(没有特别说明时),“树链剖分”都指“重链剖分”。
重链剖分可以将树上的任意一条路径划分成不超过 O ( l o g n ) O(log~n) O(log n) 条连续的链,每条链上的点深度互不相同(即是自底向上的一条链,链上所有点的 LCA 为链的一个端点)。
重链剖分还能保证划分出的每条链上的节点 DFS 序连续,因此可以方便地用一些维护序列的数据结构(如线段树)来维护树上路径的信息。
如:
-
修改 树上两点之间的路径上 所有点的值
-
查询 树上两点之间的路径上 节点权值的 和/极值/其它(在序列上可以用数据结构维护,便于合并的信息)。
例题
题目大意
对于一棵有 n n n 个节点,节点带权值的线段树,进行三种操作共 q q q 次:
-
修改单个节点的权值;
-
查询 u u u 到 v v v 的路径上的最大权值;
-
查询 u u u 到 v v v 的路径上的最大之和。
保证 1 ≤ n ≤ 30000 1 \leq n \leq 30000 1≤n≤30000 , 0 ≤ q ≤ 200000 0 \leq q \leq 200000 0≤q≤200000 。
解法
根据题面以及以上的性质,你的线段树需要维护三种操作:
-
单点修改
-
区间查询最大值
-
区间查询和
单点修改很容易实现。
由于子树的 DFS 序连续(无论是否树剖都是如此),修改一个节点的子树只用修改这一段连续的 DFS 序区间。
问题是如何修改/查询两个节点之间的路径。
考虑我们是如何用 倍增法求解 LCA 的。首先我们 将两个节点提到同一高度,然后将两个节点一起向上跳。对于树链剖分也可以使用这样的思想。
在向上跳的过程中,如果当前节点在重链上,向上跳到重链顶端,如果当前节点不在重链上,向上跳一个节点。如此直到两节点相同。沿途更新/查询区间信息。
对于每个询问,最多经过 O ( l o g n ) O(log~n) O(log n) 条重链,每条重链上线段树的复杂度为 O ( l o g n ) O(log~n) O(log n) ,因此总时间复杂度为 O ( n l o g n + q l o g 2 n ) O(n~log~n + q~log^2~n) O(n log n+q log2 n) 。实际上重链个数很难达到 O ( l o g n ) O(log~n) O(log n) (可以用完全二叉树卡满),所以树剖在一般情况下常数较小。
给出一种代码实现:
// st 是线段树结构体
int querymax(int x, int y) {
int ret = -inf, fx = top[x], fy = top[y];
while (fx != fy) {
if (dep[fx] >= dep[fy])
ret = max(ret, st.query1(1, 1, n, dfn[fx], dfn[x])), x = fa[fx];
else
ret = max(ret, st.query1(1, 1, n, dfn[fy], dfn[y])), y = fa[fy];
fx = top[x];
fy = top[y];
}
if (dfn[x] < dfn[y])
ret = max(ret, st.query1(1, 1, n, dfn[x], dfn[y]));
else
ret = max(ret, st.query1(1, 1, n, dfn[y], dfn[x]));
return ret;
}
参考代码
#include<bits/stdc++.h>
#define lc o<<1
#define rc o<<1|1
const int maxn=60010;
const int inf=2e9;
int n,a,b,w[maxn],q,u,v;
int cur,h[maxn],nxt[maxn],p[maxn];
int siz[maxn],top[maxn],son[maxn],dep[maxn],fa[maxn],dfn[maxn],rnk[maxn],cnt;
char op[10];
inline void add_edge(int x,int y){
cur++;
nxt[cur]=h[x];
h[x]=cur;
p[cur]=y;
}
struct SegTree{
int sum[maxn*4],maxx[maxn*4];
void build(int o,int l,int r){
if(l==r){
sum[o]=maxx[o]=w[rnk[l]];
return;
}
int mid=(l+r)>>1;
build(lc,l,mid);
build(rc,mid+1,r);
sum[o]=sum[lc]+sum[rc];
maxx[o]=std::max(maxx[lc],maxx[rc]);
}
int query1(int o,int l,int r,int ql,int qr){
if(l>qr||r<ql)return -inf;
if(ql<=l&&r<=qr)return maxx[o];
int mid=(l+r)>>1;
return std::max(query1(lc,l,mid,ql,qr),query1(rc,mid+1,r,ql,qr));
}
int query2(int o,int l,int r,int ql,int qr){
if(l>qr||r<ql)return 0;
if(ql<=l&&r<=qr)return sum[o];
int mid=(l+r)>>1;
return query2(lc,l,mid,ql,qr)+query2(rc,mid+1,r,ql,qr);
}
void update(int o,int l,int r,int x,int t){
if(l==r){
maxx[o]=sum[o]=t;
return;
}
int mid=(l+r)>>1;
if(x<=mid)update(lc,l,mid,x,t);
else update(rc,mid+1,r,x,t);
sum[o]=sum[lc]+sum[rc];
maxx[o]=std::max(maxx[lc],maxx[rc]);
}
}st;
void dfs1(int o){
son[o]=-1;
siz[o]=1;
for(int j=h[o];j;j=nxt[j])
if(!dep[p[j]]){
dep[p[j]]=dep[o]+1;
fa[p[j]]=o;
dfs1(p[j]);
siz[o]+=siz[p[j]];
if(son[o] ==-1||siz[p[j]]>siz[son[o]])son[o]=p[j];
}
}
void dfs2(int o,int t){
top[o]=t;
cnt++;
dfn[o]=cnt;
rnk[cnt]=o;
if(son[o]==-1)return;
dfs2(son[o],t);
for(int j=h[o];j;j=nxt[j])
if(p[j]!=son[o]&&p[j]!=fa[o])dfs2(p[j],p[j]);
}
int querymax(int x,int y){
int ret=-inf,fx=top[x],fy=top[y];
while(fx!=fy){
if(dep[fx]>=dep[fy])ret=std::max(ret,st.query1(1,1,n,dfn[fx],dfn[x])),x=fa[fx];
else ret=std::max(ret,st.query1(1,1,n,dfn[fy],dfn[y])),y=fa[fy];
fx=top[x];
fy=top[y];
}
if(dfn[x]<dfn[y])ret=std::max(ret,st.query1(1,1,n,dfn[x],dfn[y]));
else ret=std::max(ret,st.query1(1,1,n,dfn[y],dfn[x]));
return ret;
}
int querysum(int x,int y){
int ret=0,fx=top[x],fy=top[y];
while(fx!=fy){
if(dep[fx]>=dep[fy])ret+=st.query2(1,1,n,dfn[fx],dfn[x]),x=fa[fx];
else ret+=st.query2(1,1,n,dfn[fy],dfn[y]),y=fa[fy];
fx=top[x];
fy=top[y];
}
if(dfn[x]<dfn[y])ret+=st.query2(1,1,n,dfn[x],dfn[y]);
else ret+=st.query2(1,1,n,dfn[y],dfn[x]);
return ret;
}
int main(){
scanf("%d",&n);
for(int i=1;i<n;i++)scanf("%d%d",&a,&b),add_edge(a,b),add_edge(b,a);
for(int i=1;i<=n;i++)scanf("%d",w+i);
dep[1]=1;
dfs1(1);
dfs2(1,1);
st.build(1,1,n);
scanf("%d",&q);
while(q--){
scanf("%s%d%d",op,&u,&v);
if(!strcmp(op,"CHANGE"))st.update(1,1,n,dfn[u],v);
if(!strcmp(op,"QMAX"))printf("%d\n",querymax(u,v));
if(!strcmp(op,"QSUM"))printf("%d\n",querysum(u,v));
}
return 0;
}