好像是第一篇博客呢ww
**
引入
1、作用:在一棵树上进行路径的修改、求极值、求和
2、树链:树上的路径;剖分:把路径分成轻链和重链。
**
定义
**
1、deep[u]——u的深度
fa[u]——u的父亲
son[u]——u的重儿子
top[u]——u所在重链的顶端
w[u]——u在线段树中的编号
2、重儿子:u的子节点中size最大的(若有多个相同,任取一个)
轻儿子:u除重儿子外的其他节点
重边:u与son[u]的连边
轻边:u与其轻儿子的连边
重链:重边组成的路径
▲处理:dfs(以时间戳表示w数组)
▲对于标号w:同一子树的标号连续且根最小,同一重链标号连续且随deep增大而增大(证明:括号化定理)
性质
1、对于u的任意轻儿子v,则size[v]*2< size[u](易证)
2、从根到某一点的路径上轻边、重链的个数都不大于logn。
证明:
a.只有轻边时,每经过一条边e(u,v),size[v]*2 < size[u],又size[root]=n,size[ui]>=1,所以轻边数<=logn
b.连续的重边形成重链,所以重链数<=轻边数+1
实现
1、dfs1:求出fa,son,deep,size
2、dfs2:求出w,top(先搜u的重儿子——定义中w的性质2)
3、change:
a.top[x]=x则f1=fa[x],否则f1=top[x],f2处理方式相同
b.若top[x]=top[y],则x,y在同一重链上,直接修改w[x]——>w[y]
c.若deep[f1]>=deep[f2](反之处理方式相同,这里略),若top[x]=x,修改w[x],否则修改w[f1]+1——>w[x]; x=f1;
(这样可以防止x,y重叠或越界)
4、查询类似于3,略
例:修改11——>10
1、u=11, v=10, f1=2, f2=4; deep[f1]>=deep[f2],修改6——>11,u=2;
2、u=2, v=10, f1=1, f2=4; deep[f2]>=deep[f1], 修改10,v=4
3、u=2,v=4,top[u]=top[v],结束
例题
题目(bzoj突然进不去了。。。yyhs网络太渣,vjudge上的题)
bzoj1036,树的统计 Count,树链剖分模板题
#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
int n,num=0,edgenum=0,top[120005],maxx[120005],sum[120005],fa[120005],w[120005],vet[60005],next[60005],head[120005],size[120005],son[120005],deep[30005];
void add(int u,int v)
{
num++;
vet[num]=v;
next[num]=head[u];
head[u]=num;
}
void dfs1(int u,int father)
{
fa[u]=father;
son[u]=0;
int mx=0;
for (int i=head[u]; i; i=next[i])
{
int v=vet[i];
if (v==father) continue;
deep[v]=deep[u]+1;
dfs1(v,u);
size[u]+=size[v];
if (size[v]>mx)
{
mx=size[v];
son[u]=v;
}
}
}
void dfs2(int u,int t)
{
w[u]=++edgenum;
top[u]=t;
if (son[u]) dfs2(son[u],t);
for (int i=head[u]; i; i=next[i])
{
int v=vet[i];
if (v!=fa[u] && v!=son[u]) dfs2(v,v);
}
}
void insert(int u,int l,int r,int x,int val)
{
if (l==r)
{
maxx[u]=sum[u]=val;
return;
}
int mid=(l+r)>>1;
if (x<=mid) insert(u+u,l,mid,x,val);
else insert(u+u+1,mid+1,r,x,val);
sum[u]=sum[u+u]+sum[u+u+1];
maxx[u]=max(maxx[u+u],maxx[u+u+1]);
}
int cl1(int x,int y,int t)
{
if (t==1) return max(x,y);
else return x+y;
}
int find(int u,int l,int r,int x,int y,int t)
{
if (l==x && r==y)
{
if (t==1) return maxx[u];
else return sum[u];
}
int mid=(l+r)>>1;
if (y<=mid) return find(u+u,l,mid,x,y,t);
else
if (x>mid) return find(u+u+1,mid+1,r,x,y,t);
else return cl1(find(u+u,l,mid,x,mid,t),find(u+u+1,mid+1,r,mid+1,y,t),t);
}
int cl(int x,int y,int t)
{
int ans=0,f1,f2;
if (t==1) ans=-1e9;
while (top[x]!=top[y])
{
if (top[x]==x) f1=fa[x];
else f1=top[x];
if (top[y]==y) f2=fa[y];
else f2=top[y];
if (deep[f1]>=deep[f2])
{
if (top[x]==x) ans=cl1(ans,find(1,1,n,w[x],w[x],t),t);
else ans=cl1(ans,find(1,1,n,w[f1]+1,w[x],t),t);
x=f1;
}
else
{
if (top[y]==y) ans=cl1(ans,find(1,1,n,w[y],w[y],t),t);
else ans=cl1(ans,find(1,1,n,w[f2]+1,w[y],t),t);
y=f2;
}
}
ans=cl1(ans,find(1,1,n,min(w[x],w[y]),max(w[x],w[y]),t),t);
return ans;
}
int main()
{
scanf("%d",&n);
for (int i=1; i<n; i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v); add(v,u);
}
deep[1]=1;
for (int i=1; i<=n; i++) size[i]=1;
for (int i=1; i<=n*4; i++) maxx[i]=-1e9;
dfs1(1,0);
dfs2(1,1);
for (int i=1; i<=n; i++)
{
int x;
scanf("%d",&x);
insert(1,1,n,w[i],x);
}
int m;
scanf("%d",&m);
while (m--)
{
char st[10];
scanf("%s",st);
int x,y;
scanf("%d%d",&x,&y);
if (st[0]=='C') insert(1,1,n,w[x],y);
if (st[1]=='M') printf("%d\n",cl(x,y,1));
if (st[1]=='S') printf("%d\n",cl(x,y,2));
}
}
emm。。。终于写完了,看区间操作(lang)去了。。。
参考:http://blog.csdn.net/jiangshibiao/article/details/24669751