今天学习了树链剖分,记录一下。
【题目背景】
HYSBZ - 1036树的统计Count
【题目分析】
题目要求求任意结点之间路径的和以及路径上最大的结点,还有可能修改。如果正常做可能会很复杂(我也不知道正常应该怎么做,应该要用到LCA什么的,我还不太会)。
但是如果我们能够用线段树或者树状数组维护这个树,那么这种问题就会变得很简单。树链剖分就是这样一种将树映射在一个数组上变成线性结构然后用线段树进行维护的数据结构。
【基础知识】
- 重儿子:儿子中子树结点数目最多的那个儿子(size最大)
- 重边:父亲结点和重儿子连成的边
- 重链:由多条重边连接而成的路径
- 轻儿子:除了重儿子的其他儿子
- 轻边:父亲和轻儿子连成的边
如图所示,红圈的表示重儿子,黑边表示重边。由黑边组成的链为重链。
【具体实现】
我们先进行一次遍历得到重儿子以及深度等信息储存起来
void dfs1(int u,int f)
{
int i,v;
siz[u]=1; //储存该结点子树的大小(最小只有自身一个结点)
son[u]=0; //储存重儿子
fa[u]=f; //储存父节点
h[u]=h[f]+1;//储存深度
for(i=0;i<g[u].size();i++)
{
v=g[u][i];
if(v!=f)
{
dfs1(v,u); //深度优先遍历
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
}
}
得到以上数据后,我们可以按重链将树映射在一个数组上。从根节点开始,优先将重链映射到数组上,然后按照深度依次进行轻儿子,轻儿子又是某一个重链的开始(每一个节点都处于一个且仅有一个重链中)。记录每条每个节点所属重链的开头(从而判断两个节点是否在同一个重链上)。
void dfs2(int u,int f,int k)
{
int i,v;
top[u]=k; //记录所属重链的开头
pos[u]=++cnt;//映射到数组上的下标(同一个重链的下标是连续的)
A[cnt]=val[u];//确定数组所对应节点的值方便进行维护
if(son[u]) dfs2(son[u],u,k);//优先遍历重儿子,从而得到连续的重链
for(i=0;i<g[u].size();i++)
{
v=g[u][i];
if(v!=f&&v!=son[u]) dfs2(v,u,v); //遍历其他轻儿子
}
}
成功将树映射到数组上以后我们再用线段树对数组进行维护。对于线段树的维护是常规操作。
void update(int k,int l,int r,int x,int v)
{
if(l==r)
{
Sum[k]=Max[k]=v;
return;
}
int mid=(l+r)/2;
if(x<=mid) update(k<<1,l,mid,x,v);
else update(k<<1|1,mid+1,r,x,v);
Sum[k]=Sum[k<<1]+Sum[k<<1|1];
Max[k]=max(Max[k<<1],Max[k<<1|1]);
}
int QuerySum(int k,int l,int r,int L,int R)
{
if(L<=l && r<=R) return Sum[k];
int mid=(l+r)/2;
int ret=0;
if(L<=mid) ret+=QuerySum(k<<1,l,mid,L,R);
if(R>mid) ret+=QuerySum(k<<1|1,mid+1,r,L,R);
return ret;
}
int QueryMax(int k,int l,int r,int L,int R)
{
if(L==l && r==R) return Max[k];
int mid=(l+r)/2;
if(R<=mid) return QueryMax(k<<1,l,mid,L,R);
else if(L>mid) return QueryMax(k<<1|1,mid+1,r,L,R);
else return max(QueryMax(k<<1,l,mid,L,mid),QueryMax(k<<1|1,mid+1,r,mid+1,R));
}
重点还是对于树上两个点如何得到他们之间的一条路径以及这个路径在映射数组中的位置。我们每次从深度更深的点向上升,直到两个节点处在同一条链中(或者处于同一节点处)。在上升的过程中记录每条链的值(每条链都处于映射数组的一个连续的区间内)
int FindSum(int u,int v)
{
int ans=0;
while(top[u]!=top[v])
{
if(h[top[u]]<h[top[v]]) swap(u,v);
ans+=QuerySum(1,1,n,pos[top[u]],pos[u]);
u=fa[top[u]];
}
if(h[u]<h[v]) swap(u,v);
ans+=QuerySum(1,1,n,pos[v],pos[u]);
return ans;
}
int FindMax(int u,int v)
{
int ans=INT_MIN;
while(top[u]!=top[v])
{
if(h[top[u]]<h[top[v]]) swap(u,v);
ans=max(ans,QueryMax(1,1,n,pos[top[u]],pos[u]));
u=fa[top[u]];
}
if(h[u]<h[v]) swap(u,v);
ans=max(ans,QueryMax(1,1,n,pos[v],pos[u]));
return ans;
}
这样我们就成功做到用线段树维护树状结构的数据啦
【AC代码】
#include<iostream>
#include<cstdio>
#include<vector>
#include<cmath>
#include<cstring>
#include<algorithm>
#include<climits>
using namespace std;
const int MAXN=30010;
vector<int>g[MAXN];
int fa[MAXN],A[MAXN],val[MAXN],pos[MAXN],siz[MAXN],son[MAXN],h[MAXN],top[MAXN];
int cnt=0,n,m;
int Sum[MAXN<<2],Max[MAXN<<2];
void dfs1(int u,int f)
{
int i,v;
siz[u]=1;
son[u]=0;
fa[u]=f;
h[u]=h[f]+1;
for(i=0;i<g[u].size();i++)
{
v=g[u][i];
if(v!=f)
{
dfs1(v,u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
}
}
void dfs2(int u,int f,int k)
{
int i,v;
top[u]=k;
pos[u]=++cnt;
A[cnt]=val[u];
if(son[u]) dfs2(son[u],u,k);
for(i=0;i<g[u].size();i++)
{
v=g[u][i];
if(v!=f&&v!=son[u]) dfs2(v,u,v);
}
}
void update(int k,int l,int r,int x,int v)
{
if(l==r)
{
Sum[k]=Max[k]=v;
return;
}
int mid=(l+r)/2;
if(x<=mid) update(k<<1,l,mid,x,v);
else update(k<<1|1,mid+1,r,x,v);
Sum[k]=Sum[k<<1]+Sum[k<<1|1];
Max[k]=max(Max[k<<1],Max[k<<1|1]);
}
int QuerySum(int k,int l,int r,int L,int R)
{
if(L<=l && r<=R) return Sum[k];
int mid=(l+r)/2;
int ret=0;
if(L<=mid) ret+=QuerySum(k<<1,l,mid,L,R);
if(R>mid) ret+=QuerySum(k<<1|1,mid+1,r,L,R);
return ret;
}
int QueryMax(int k,int l,int r,int L,int R)
{
if(L==l && r==R) return Max[k];
int mid=(l+r)/2;
if(R<=mid) return QueryMax(k<<1,l,mid,L,R);
else if(L>mid) return QueryMax(k<<1|1,mid+1,r,L,R);
else return max(QueryMax(k<<1,l,mid,L,mid),QueryMax(k<<1|1,mid+1,r,mid+1,R));
}
int FindSum(int u,int v)
{
int ans=0;
while(top[u]!=top[v])
{
if(h[top[u]]<h[top[v]]) swap(u,v);
ans+=QuerySum(1,1,n,pos[top[u]],pos[u]);
u=fa[top[u]];
}
if(h[u]<h[v]) swap(u,v);
ans+=QuerySum(1,1,n,pos[v],pos[u]);
return ans;
}
int FindMax(int u,int v)
{
int ans=INT_MIN;
while(top[u]!=top[v])
{
if(h[top[u]]<h[top[v]]) swap(u,v);
ans=max(ans,QueryMax(1,1,n,pos[top[u]],pos[u]));
u=fa[top[u]];
}
if(h[u]<h[v]) swap(u,v);
ans=max(ans,QueryMax(1,1,n,pos[v],pos[u]));
return ans;
}
int main()
{
int a,b,i;
char s[10];
scanf("%d",&n);
for(i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
g[a].push_back(b);
g[b].push_back(a);
}
for(i=1;i<=n;i++) scanf("%d",&val[i]);
dfs1(1,0);
dfs2(1,0,1);
for(i=1;i<=n;i++) update(1,1,n,i,A[i]);
scanf("%d",&m);
while(m--)
{
scanf("%s%d%d",s,&a,&b);
if(s[1]=='H') update(1,1,n,pos[a],b);
else if(s[1]=='S') printf("%d\n",FindSum(a,b));
else printf("%d\n",FindMax(a,b));
}
return 0;
}