第一次学树剖,推荐一篇大佬博客:https://www.cnblogs.com/KatouKatou/p/9557540.html
树剖关键代码:
void dfs1(int u,int f){
sze[u]=1;
for (int i=head[u];~i;i=edge[i].nxt){
int v=edge[i].v;
if(v==f)continue;
dep[v]=dep[u]+1;fa[v]=u;
dfs1(v,u);
sze[u]+=sze[v];
if (sze[v]>sze[son[u]])son[u]=v;
}
}
void dfs2(int u,int Top)
{
dfn[u]=++cnt;///按重链优先编号,使重儿子连续
rk[cnt]=u;
top[u]=Top;
if(son[u]) dfs2(son[u],Top);///继续递归u的重儿子,故重链顶端还是Top
for(int i=head[u];~i;i=edge[i].nxt){
int v=edge[i].v;
if(v!=fa[u]&&v!=son[u])///既不是u的父亲,也不是u的重儿子时,说明v只可能是u的轻儿子
dfs2(v,v);///轻儿子的Top还是自己
}
}
核心:树链剖分就是将树分割成多条链,然后利用数据结构(线段树、树状数组等)来维护这些链。 关键是利用top的加速。
例题:https://www.luogu.org/problemnew/show/P2590
分析:显然是线段树维护树链剖分,支持单点修改,区间查询最大值,区间求和。
每个数组的含义:
sze[i]:以i为根结点的子树的所有结点数。
fa[i]:i结点的父结点
son[i]:i结点的重儿子
top[i]:i结点的重链顶端结点
dep[i]:i结点的深度,dep[1]=1
dfn[i]: i结点的dfs序编号
rk[i]:dfs序编号为i的结点在原来树中对应的编号
Ac code:
///线段树维护树链剖分
#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+10;
int a[maxn],tot,cnt,n;
int head[maxn];
struct Edge{
int v,nxt;
}edge[maxn<<1];
int sze[maxn],fa[maxn],son[maxn],top[maxn],dep[maxn];
int dfn[maxn],rk[maxn];
const int INF=1e9;
struct Tree{
int _max;
int sum;
}tree[maxn<<2];
void init()
{
cnt=0;
tot=0;
dep[1]=1;
memset(head,-1,sizeof head);
}
void addedge(int u,int v)
{
edge[tot].v=v;
edge[tot].nxt=head[u];
head[u]=tot++;
}
void dfs1(int u,int f){
sze[u]=1;
for (int i=head[u];~i;i=edge[i].nxt){
int v=edge[i].v;
if(v==f)continue;
dep[v]=dep[u]+1;fa[v]=u;
dfs1(v,u);
sze[u]+=sze[v];
if (sze[v]>sze[son[u]])son[u]=v;
}
}
void dfs2(int u,int Top)
{
dfn[u]=++cnt;///按重链优先编号,使重儿子连续
rk[cnt]=u;
top[u]=Top;
if(son[u]) dfs2(son[u],Top);
for(int i=head[u];~i;i=edge[i].nxt){
int v=edge[i].v;
if(v!=fa[u]&&v!=son[u])
dfs2(v,v);
}
}
void pushup(int rt)
{
tree[rt].sum=tree[rt<<1].sum+tree[rt<<1|1].sum;
tree[rt]._max=max(tree[rt<<1]._max,tree[rt<<1|1]._max);
}
void buildtree(int rt,int l,int r)
{
if(l==r){
tree[rt].sum=a[rk[l]];
tree[rt]._max=a[rk[l]];///注意是在原来树中对应的编号
return;
}
int mid=(l+r)>>1;
buildtree(rt<<1,l,mid);
buildtree(rt<<1|1,mid+1,r);
pushup(rt);
}
void update(int rt,int l,int r,int p,int val)
{
if(l==r){
tree[rt].sum=val;
tree[rt]._max=val;
return;
}
int mid=(l+r)>>1;
if(p<=mid) update(rt<<1,l,mid,p,val);
else update(rt<<1|1,mid+1,r,p,val);
pushup(rt);
}
int query_max(int rt,int l,int r,int L,int R)
{
if(L<=l&&r<=R){
return tree[rt]._max;
}
int ans=-INF;
int mid=(l+r)>>1;
if(mid>=L) ans=max(ans,query_max(rt<<1,l,mid,L,R));
if(mid<R) ans=max(ans,query_max(rt<<1|1,mid+1,r,L,R));
pushup(rt);
return ans;
}
int query_sum(int rt,int l,int r,int L,int R)
{
if(L<=l&&r<=R){
return tree[rt].sum;
}
int ans=0;
int mid=(l+r)>>1;
if(mid>=L) ans+=query_sum(rt<<1,l,mid,L,R);
if(mid<R) ans+=query_sum(rt<<1|1,mid+1,r,L,R);
pushup(rt);
return ans;
}
int qmax(int u,int v)
{
int ans=-INF;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
ans=max(ans,query_max(1,1,n,dfn[top[u]],dfn[u]));///更新重链顶端到u的每个结点的值
u=fa[top[u]];///top[u]!=top[v],此时深的结点u应跳到top[u]的父节点
}
if(dep[u]<dep[v]) swap(u,v);
ans=max(ans,query_max(1,1,n,dfn[v],dfn[u]));
return ans;
}
int qsum(int u,int v)
{
int ans=0;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
ans+=query_sum(1,1,n,dfn[top[u]],dfn[u]);
u=fa[top[u]];///top[u]!=top[v],此时深的结点u应跳到top[u]的父节点
}
if(dep[u]<dep[v]) swap(u,v);
ans+=query_sum(1,1,n,dfn[v],dfn[u]);
return ans;
}
int main()
{
int m,u,v;
scanf("%d",&n);
init();
for(int i=1;i<=n-1;i++){
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
fa[1]=1;
dfs1(1,-1);
dfs2(1,1);
buildtree(1,1,n);
scanf("%d",&m);
char op[10];
while(m--){
scanf("%s%d%d",op,&u,&v);
if(op[1]=='H'){
update(1,1,n,dfn[u],v);
}
else if(op[1]=='M'){
printf("%d\n",qmax(u,v));
}
else{
printf("%d\n",qsum(u,v));
}
}
return 0;
}
再来个模板题
例题:https://www.luogu.org/problemnew/show/P3384
分析:需要完成四种操作。
1、路径修改
2、路径查询
3、子树修改
4、子树查询
1,2直接可以树链剖分+线段树完成
3,4不需要树剖,找到需要查询子树的根结点和该子树dfs最后一个结点的编号即可,即对区间[dfn[u],dfn[u]+sze[u]-1]进行操作
dfn[u]为dfs序编号。
Ac code:
/P3384
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=1e5+10;
int head[maxn],a[maxn],n;
struct Edge{
int v,nxt;
}edge[maxn<<1];
int tot,cnt;
int fa[maxn],son[maxn],dfn[maxn],rk[maxn],sze[maxn],dep[maxn],top[maxn];
int p;
struct Tree{
int val,lazy;
}tree[maxn<<2];
void init()
{
tot=cnt=0;
memset(head,-1,sizeof head);
}
void addedge(int u,int v)
{
edge[tot].v=v;
edge[tot].nxt=head[u];
head[u]=tot++;
}
void dfs1(int u,int faa)
{
sze[u]=1;
for(int i=head[u];~i;i=edge[i].nxt)
{
int v=edge[i].v;
if(v==faa) continue;
fa[v]=u;
dep[v]=dep[u]+1;
dfs1(v,u);
sze[u]+=sze[v];
if(sze[v]>sze[son[u]])
son[u]=v;
}
}
void dfs2(int u,int Top)
{
dfn[u]=++cnt;
rk[cnt]=u;
top[u]=Top;
if(son[u]) dfs2(son[u],Top);
for(int i=head[u];~i;i=edge[i].nxt){
int v=edge[i].v;
if(v!=fa[u]&&v!=son[u]){
dfs2(v,v);
}
}
}
int lca(int u,int v)
{
while(top[u]!=top[v])
{
if(dep[top[u]]>dep[top[v]])
u=fa[top[u]];
else
v=fa[top[v]];
}
return dep[u]<dep[v]?u:v;
}
void pushup(int rt)
{
tree[rt].val=tree[rt<<1].val+tree[rt<<1|1].val;
}
void pushdown(int rt,int l,int r)
{
int len=r-l+1;
if(tree[rt].lazy){
tree[rt<<1].lazy+=tree[rt].lazy;
tree[rt<<1].val+=tree[rt].lazy*(len-(len>>1));
tree[rt<<1|1].lazy+=tree[rt].lazy;
tree[rt<<1|1].val+=tree[rt].lazy*(len>>1);
tree[rt<<1].val%=p;
tree[rt<<1|1].val%=p;
tree[rt].lazy=0;
}
}
void buildtree(int rt,int l,int r)
{
if(l==r){
tree[rt].val=a[rk[l]]%p;
tree[rt].lazy=0;
return;
}
int mid=(l+r)>>1;
buildtree(rt<<1,l,mid);
buildtree(rt<<1|1,mid+1,r);
pushup(rt);
}
void update(int rt,int l,int r,int L,int R,int val)
{
if(L<=l&&r<=R){
tree[rt].lazy+=val;
tree[rt].val+=val*(r-l+1);
return;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
if(mid>=L) update(rt<<1,l,mid,L,R,val);
if(mid<R) update(rt<<1|1,mid+1,r,L,R,val);
pushup(rt);
}
int query(int rt,int l,int r,int L,int R)
{
if(L<=l&&r<=R){
return tree[rt].val%p;
}
pushdown(rt,l,r);
int mid=(l+r)>>1;
int ans=0;
if(mid>=L) ans=(ans+query(rt<<1,l,mid,L,R))%p;
if(mid<R) ans=(ans+query(rt<<1|1,mid+1,r,L,R))%p;
return ans;
}
int qrange(int u,int v)
{
int ans=0;
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
ans=(ans+query(1,1,n,dfn[top[u]],dfn[u]))%p;
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
ans=(ans+query(1,1,n,dfn[v],dfn[u]))%p;
return ans;
}
void uprange(int u,int v,int val)
{
while(top[u]!=top[v]){
if(dep[top[u]]<dep[top[v]]) swap(u,v);
update(1,1,n,dfn[top[u]],dfn[u],val);
u=fa[top[u]];
}
if(dep[u]<dep[v]) swap(u,v);
update(1,1,n,dfn[v],dfn[u],val);
}
int qson(int u)
{
return query(1,1,n,dfn[u],dfn[u]+sze[u]-1)%p;
}
void upson(int u,int val)
{
update(1,1,n,dfn[u],dfn[u]+sze[u]-1,val);
}
int main()
{
int m,r;
scanf("%d%d%d%d",&n,&m,&r,&p);
init();
dep[r]=1;
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
int u,v;
for(int i=1;i<n;i++)
{
scanf("%d%d",&u,&v);
addedge(u,v);
addedge(v,u);
}
fa[r]=0;
dfs1(r,0);
dfs2(r,r);
buildtree(1,1,n);
for(int i=1;i<=n;i++) cout<<"top="<<rk[i]<<' '<<sze[i]<<endl;
int op,x,y,z;
while(m--){
scanf("%d",&op);
if(op==1){
scanf("%d%d%d",&x,&y,&z);
uprange(x,y,z);
}
else if(op==2){
scanf("%d%d",&x,&y);
printf("%d\n",qrange(x,y)%p);
}
else if(op==3){
scanf("%d%d",&x,&z);
upson(x,z);
}
else{
scanf("%d",&x);
printf("%d\n",qson(x)%p);
}
}
return 0;
}