树链剖分算是一个应用比较广泛而且比较好实现的一种方法,其大体思想主要是把树链分成轻链和重链,这样既可以套用数据结构也可以求LCA。
首先定义 重儿子和轻儿子,对于一个非叶子节点,它的所有儿子中子树大小最大的即为重儿子,其余的为轻儿子。
我们可以用一次dfs来求出所有的重儿子,重儿子连成的树链即为重链,其余的为轻链,可以证明所有的重链不超过logn条。
对于每个节点还要维护其深度大小…之后需要…
求重儿子代码
void dfs1(int u,int deep,int fath)
{
dep[u]=deep;siz[u]=1;fa[u]=fath;
for(int i=0;i<f[u].size();i++)
{
int v=f[u][i];
if(v==fath) continue;
dfs1(v,deep+1,u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
}
之后将重儿子连成重链,top数组记录着当前结点所在重链的顶端结点。
并将节点编号便于在线段树上维护。
void dfs2(int u,int tp)
{
top[u]=tp;
id[u]=++cnt;
dfth[cnt]=u;
if(son[u]) dfs2(son[u],tp);
for(int i=0;i<f[u].size();i++)
{
int v=f[u][i];
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
这样连下来,每个重链的编号都是连续的,可以用线段树等数据结构维护。
下为BZOJ 1036的代码
#include<cstdio>
#include<iostream>
#include<vector>
using namespace std;
const int maxn=2011010;
vector<int> f[maxn];
int val[maxn],dep[maxn],siz[maxn],son[maxn],id[maxn],top[maxn],fa[maxn];
int x,y,cnt;
int T_sum[maxn<<2],T_max[maxn<<2],dfth[maxn];
int query_sum(int rt,int l,int r,int x,int y)
{
if(x<=l&&y>=r)
{
return T_sum[rt];
}
int mid=(l+r)>>1;
if(y<=mid) return query_sum(rt<<1,l,mid,x,y);
if(x>mid) return query_sum(rt<<1|1,mid+1,r,x,y);
return query_sum(rt<<1,l,mid,x,y)+query_sum(rt<<1|1,mid+1,r,x,y);
}
int query_max(int rt,int l,int r,int x,int y)
{
if(x<=l&&y>=r)
{
return T_max[rt];
}
int mid=(l+r)>>1;
if(y<=mid) return query_max(rt<<1,l,mid,x,y);
if(x>mid) return query_max(rt<<1|1,mid+1,r,x,y);
return max(query_max(rt<<1,l,mid,x,y),query_max(rt<<1|1,mid+1,r,x,y));
}
void dfs1(int u,int deep,int fath)
{
dep[u]=deep;
siz[u]=1;
fa[u]=fath;
for(int i=0;i<f[u].size();i++)
{
int v=f[u][i];
if(v==fath) continue;
dfs1(v,deep+1,u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
}
void dfs2(int u,int tp)
{
top[u]=tp;
id[u]=++cnt;
dfth[cnt]=u;
if(son[u]) dfs2(son[u],tp);
for(int i=0;i<f[u].size();i++)
{
int v=f[u][i];
if(v==fa[u]||v==son[u]) continue;
dfs2(v,v);
}
}
int get_sum(int u,int v)
{
int sum=0;
int x=top[u],y=top[v];
while(x!=y)
{
if(dep[x]<dep[y])
{
swap(u,v);swap(x,y);
}
sum+=query_sum(1,1,cnt,id[x],id[u]);
u=fa[x];x=top[u];
}
if(dep[u]>dep[v]) swap(u,v);
return sum+query_sum(1,1,cnt,id[u],id[v]);
}
int get_max(int u,int v)
{
int ans=-200000000;
int x=top[u],y=top[v];
while(x!=y)
{
if(dep[x]<dep[y])
{
swap(u,v);swap(x,y);
}
ans=max(ans,query_max(1,1,cnt,id[x],id[u]));
u=fa[x];x=top[u];
}
if(dep[u]>dep[v]) swap(u,v);
return max(ans,query_max(1,1,cnt,id[u],id[v]));
}
void build(int rt,int l,int r)
{
if(l==r)
{
T_max[rt]=T_sum[rt]=val[dfth[l]];
return;
}
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
T_max[rt]=max(T_max[rt<<1],T_max[rt<<1|1]);
T_sum[rt]=T_sum[rt<<1]+T_sum[rt<<1|1];
}
void query_change(int rt,int l,int r,int x,int val)
{
if(l==r)
{
T_max[rt]=T_sum[rt]=val;
return;
}
int mid=(l+r)>>1;
if(x<=mid) query_change(rt<<1,l,mid,x,val);
else query_change(rt<<1|1,mid+1,r,x,val);
T_max[rt]=max(T_max[rt<<1],T_max[rt<<1|1]);
T_sum[rt]=T_sum[rt<<1]+T_sum[rt<<1|1];
}
int main()
{
int n;
scanf("%d",&n);
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
f[x].push_back(y);
f[y].push_back(x);
}
for(int i=1;i<=n;i++) scanf("%d",&val[i]);
dfs1(1,0,1);
dfs2(1,1);
build(1,1,cnt);
int z;char c[10];
scanf("%d",&z);
for(int i=1;i<=z;i++)
{
scanf("%s",c);
if(c[0]=='C')
{
scanf("%d%d",&x,&y);
query_change(1,1,cnt,id[x],y);
}
if(c[0]=='Q')
{
if(c[1]=='S')
{
scanf("%d%d",&x,&y);
printf("%d\n",get_sum(x,y));
}
if(c[1]=='M')
{
scanf("%d%d",&x,&y);
printf("%d\n",get_max(x,y));
}
}
}
return 0;
}
不喜欢vector的同学可以选择邻接表QAQ