【算法介绍】
树链剖分就是将树分割成多条链,然后利用数据结构(线段树、树状数组等)来维护这些链。
当树的形态不发生改变的时候,我们可以先对其进行链的剖分,每条链就相当于一个序列,操作就可以被拆分成几条完整的链来解决,然后利用一些数据结构加以维护即可。
我们希望通过这样的方式,来达到一些树上修改、计算的目的
【算法流程】
轻重链剖分
概念引入:把一个节点u的所有儿子中size[v]最大的一个作为重儿子,则称(u,v)为重边,u到其余的儿子v'的边(u,v')为轻边。
性质:(u,v)为轻边,那么size(v)<=size(u)/2 => 从根到某一点v的轻边数量不超过v
这两点性质保证了树剖的时间复杂度为O(nlogn)
具体的实现方法:
1.第一次dfs算出father[x],dep[x],size[x],son[x]
void dfs1(int u,int f)
{
siz[u]=1; fa[u]=f; dep[u]=dep[f]+1;
for(int i=head[u];i;i=e[i].nxt)
{
int to=e[i].to;
if(to==f) continue;
dfs1(to,u);
siz[u]+=siz[to];
if(siz[to]>siz[son[u]]) son[u]=to;
}
}
2.第二次dfs算出top[x](所在重路径的顶部节点),seg[x](x在线段树中的位置),rev[x](seg的反数组,rev[seg[x]]=x)
注意:为了保证时间效率,要把重链按顺序排在连续的编号上
void dfs2(int u,int f)
{
if(son[u])
{
seg[son[u]]=++seg[0];
top[son[u]]=top[u];
rev[seg[0]]=son[u];
dfs2(son[u],u);
}
for(int i=head[u];i;i=e[i].nxt)
{
int to=e[i].to;
if(!top[to])
{
seg[to]=++seg[0];
top[to]=to;
rev[seg[0]]=to;
dfs2(to,u);
}
}
}
3.查找路径(u,v)信息,我们可以把任意路径拆分成若干重路径,也就是若干区间,并用线段树处理操作
其实本质就是用top加速求LCA
对于u,v,深度较大的一个点向上跳到fa[top[u]]处,处理知道它们出现在同一个重链上
int sum(int x,int y)
{
int ans=0,fx=top[x],fy=top[y];
while(fx!=fy) //两点不在同一条重链
{
if(d[fx]>=d[fy])
{
ans+=query(id[fx],id[x],rt); //线段树区间求和,处理这条重链的贡献
x=f[fx],fx=top[x]; //将x设置成原链头的父亲结点,走轻边,继续循环
}
else
{
ans+=query(id[fy],id[y],rt);
y=f[fy],fy=top[y];
}
}
//循环结束,两点位于同一重链上,但两点不一定为同一点,所以我们还要统计这两点之间的贡献
if(id[x]<=id[y])
ans+=query(id[x],id[y],rt);
else
ans+=query(id[y],id[x],rt);
return ans;
}
【习题】
1.[SDOI2011]染色
【题意】
给定一棵 n个节点的无根树,共有 m 个操作,操作分为两种:
- 将节点 a 到节点 b 的路径上的所有点(包括 a 和 b)都染成颜色 c。
- 询问节点 a 到节点 b 的路径上的颜色段数量。
颜色段的定义是极长的连续相同颜色被认为是一段。例如 112221
由三段组成:11
、222
、1
。
【分析】
很显然,这个树上的操作我们可以用树链剖分来优化,这样呢,我们需要做的便是一段链的颜色块数统计
这个是线段树较为经典的操作了,记录lc,rc为这个区间最左侧的颜色和最右侧的颜色
在pushup和统计数量的时候如果lson的rc和rson的lc相等,那么二者合并的颜色块数量-1
其他的情况就是普通的处理
注意细节:在处理路径的时候,我们记录一个pos1表示当前要往上跳的一支,上次处理的最后一个颜色为pos1,如果在更新完当前这枝后,lc==pos1那么要res-1(因为二者可以合并成为一个颜色块)
错误:①在交换x,y的同时,也要交换pos1和pos2,因为这样才是把两支完全交换 ②dfs2时候,走轻边的时候fa要写成to,表示top[x]=x,自己是独立的重链
③dfs2判断当前枚举的下一个点不是父亲节点时候,要注意不是用fa,而是f[x],因为fa记录的是重链子的最高点
【代码】
#include<bits/stdc++.h>
using namespace std;
int n,m,lc,rc;
const int maxn=1e5+5;
int head[maxn],tot,col[maxn];
struct edge
{
int to,nxt;
}e[maxn<<1];
void add(int x,int y)
{
e[++tot].to=y; e[tot].nxt=head[x]; head[x]=tot;
}
int dep[maxn],size[maxn],son[maxn],f[maxn],top[maxn],seg[maxn],rev[maxn];
int dfs_time;
void dfs1(int x,int fa)
{
f[x]=fa; dep[x]=dep[fa]+1; size[x]=1;
for(int i=head[x];i;i=e[i].nxt)
{
int to=e[i].to;
if(to==fa) continue;
dfs1(to,x);
size[x]+=size[to];
if(size[to]>size[son[x]])
son[x]=to;
}
}
void dfs2(int x,int fa)
{
top[x]=fa; seg[x]=++dfs_time; rev[dfs_time]=x;
if(!son[x]) return;
dfs2(son[x],fa);
for(int i=head[x];i;i=e[i].nxt)
{
int to=e[i].to;
if(to==son[x] || to==f[x]) continue; //RE2
dfs2(to,to); //WA1
}
}
struct segtree
{
int l,r,lc,rc,val;
int flag;
}tr[maxn<<2];
#define lson now<<1
#define rson now<<1|1
void pushup(int now)
{
tr[now].val=tr[lson].val+tr[rson].val;
if(tr[lson].rc==tr[rson].lc) --tr[now].val;
tr[now].lc=tr[lson].lc;
tr[now].rc=tr[rson].rc;
}
void build(int now,int l,int r)
{
tr[now].l=l; tr[now].r=r;
if(l==r)
{
tr[now].lc=tr[now].rc=col[rev[l]];
tr[now].val=1;
return;
}
int mid=l+r>>1;
build(lson,l,mid);
build(rson,mid+1,r);
pushup(now);
}
void pushcol(int now,int x)
{
tr[now].lc=x; tr[now].rc=x;
tr[now].val=1; tr[now].flag=x;
}
void pushdown(int now)
{
if(!tr[now].flag) return;
if(lson) pushcol(lson,tr[now].flag);
if(rson) pushcol(rson,tr[now].flag);
tr[now].flag=0;
}
int query(int now,int L,int R)
{
int l=tr[now].l,r=tr[now].r;
if(L<=l && R>=r)
{
if(l==L) lc=tr[now].lc;
if(r==R) rc=tr[now].rc;
return tr[now].val;
}
pushdown(now);
int mid=l+r>>1;
if(R<=mid) return query(lson,L,R);
if(L>mid) return query(rson,L,R);
int res=query(lson,L,R)+query(rson,L,R);
if(tr[lson].rc==tr[rson].lc) --res;
return res;
}
int ask(int x,int y)
{
int res=0;
int pos1=0,pos2=0;
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y),swap(pos1,pos2);
res+=query(1,seg[top[x]],seg[x]);
if(rc==pos1) --res;
pos1=lc; x=f[top[x]];
}
if(seg[x]>seg[y]) swap(x,y),swap(pos1,pos2);
res+=query(1,seg[x],seg[y]);
if(lc==pos1) --res;
if(rc==pos2) --res;
return res;
}
void modify(int now,int L,int R,int x)
{
int l=tr[now].l,r=tr[now].r;
if(L<=l && R>=r)
{
pushcol(now,x);
return;
}
pushdown(now);
int mid=l+r>>1;
if(L<=mid) modify(lson,L,R,x);
if(R>mid) modify(rson,L,R,x);
pushup(now);
}
void updatecol(int x,int y,int color)
{
while(top[x]!=top[y])
{
if(dep[top[x]]<dep[top[y]]) swap(x,y);
modify(1,seg[top[x]],seg[x],color);
x=f[top[x]];
}
if(seg[x]>seg[y]) swap(x,y);
modify(1,seg[x],seg[y],color);
}
int main()
{
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) scanf("%d",&col[i]);
int x,y,z;
for(int i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y); add(y,x);
}
dfs1(1,0); dfs2(1,1);
build(1,1,dfs_time);
for(int i=1;i<=m;i++)
{
char s[5];
scanf("%s",s);
if(s[0]=='Q')
{
scanf("%d%d",&x,&y);
printf("%d\n",ask(x,y));
}
else
{
scanf("%d%d%d",&x,&y,&z);
updatecol(x,y,z);
}
}
return 0;
}
2.[SDOI2014]旅行
【题意】
树上两个点,旅行时走最短路,每个点都有一个宗教值和评级,有以下四种操作:①修改某个点的宗教值 ②修改某个点的评级 ③查询路径上宗教为c的评级和 ④查询路径上宗教为c得到评级最大值
【分析】
很明显如果没有宗教的限制,对于树上的操作我们要用树链剖分和线段树来解决,由于加入了一个宗教且(),我们可以建立c棵线段树分别记录即可
注意:要动态开点,否则会MLE!对于修改操作,把原来所在线段树上的对应值修改成0即可
时空复杂度:
【代码】
待补
3.P4211 [LNOI2014]LCA
【题意】
q次询问区间l-r中每个点到z的lca的深度和
【分析】
很显然,每次询问固定的是z,把l-r到根的路径上每个边+1,然后对于z跑一次公共路径长度即可
所以我们需要先把l,r进行离线排序,再每次进行修改
【代码】
#include<bits/stdc++.h>
using namespace std;
int n,q;
const int maxn=4e5+5;
const int mod=201314;
int head[maxn],tot,cnt;
struct edge
{
int to,nxt;
}e[maxn];
void add(int x,int y)
{
e[++tot].to=y; e[tot].nxt=head[x]; head[x]=tot;
}
struct query
{
int p,no,flag;
}a[maxn<<1];
struct answer
{
int z,ans1,ans2;
}qq[maxn];
struct segtree
{
int lz,l,r,sum;
}tr[maxn<<2];
bool cmp(query a,query b)
{
return a.p<b.p;
}
int f[maxn],dep[maxn],son[maxn],siz[maxn],top[maxn],dfn[maxn],rev[maxn];
void dfs1(int u,int fa)
{
f[u]=fa; dep[u]=dep[fa]+1; siz[u]=1;
for(int i=head[u];i;i=e[i].nxt)
{
int to=e[i].to;
if(to==fa) continue;
dfs1(to,u);
siz[u]+=siz[to];
if(siz[son[u]]<siz[to]) son[u]=to;
}
}
int dfstime;
void dfs2(int u,int fa)
{
top[u]=fa; dfn[u]=++dfstime; rev[dfstime]=u;
if(!son[u]) return;
dfs2(son[u],fa);
for(int i=head[u];i;i=e[i].nxt)
{
int to=e[i].to;
if(to==son[u] || to==f[u]) continue;
dfs2(to,to);
}
}
#define lson now<<1
#define rson now<<1|1
void build(int now,int l,int r)
{
tr[now].l=l,tr[now].r=r;
if(l==r)
{
tr[now].sum=0;
return;
}
int mid=l+r>>1;
build(lson,l,mid);
build(rson,mid+1,r);
}
void pushup(int now)
{
tr[now].sum=tr[lson].sum+tr[rson].sum;
}
void pushdown(int now)
{
if(tr[now].l==tr[now].r || !tr[now].lz) return;
tr[lson].sum+=tr[now].lz*(tr[lson].r-tr[lson].l+1);
tr[lson].lz+=tr[now].lz;
tr[rson].sum+=tr[now].lz*(tr[rson].r-tr[rson].l+1);
tr[rson].lz+=tr[now].lz;
tr[now].lz=0;
}
void update(int now,int l,int r)
{
pushdown(now);
if(tr[now].l==l && tr[now].r==r)
{
tr[now].lz++;
tr[now].sum+=tr[now].r-tr[now].l+1;
return;
}
int mid=tr[now].l+tr[now].r>>1;
if(r<=mid)update(lson,l,r);
else if(l>mid)update(rson,l,r);
else
{
update(lson,l,mid);
update(rson,mid+1,r);
}
pushup(now);
}
void modify(int x,int y)
{
while(top[x]!=top[y])
{
update(1,dfn[top[x]],dfn[x]);
x=f[top[x]];
}
update(1,dfn[y],dfn[x]);
}
int calc(int now,int l,int r)
{
pushdown(now);
if(tr[now].l==l && tr[now].r==r) return tr[now].sum;
int mid=tr[now].l+tr[now].r>>1;
int res=0;
if(r<=mid) return calc(lson,l,r);
else if(l>mid) return calc(rson,l,r);
else return calc(lson,l,mid)+calc(rson,mid+1,r);
}
int query(int x,int y)
{
int res=0;
while(top[x]!=top[y])
{
res+=calc(1,dfn[top[x]],dfn[x]);
res%=mod;
x=f[top[x]];
}
res+=calc(1,dfn[y],dfn[x]); res%=mod;
return res;
}
int main()
{
freopen("lca.in","r",stdin);
freopen("lca.out","w",stdout);
scanf("%d%d",&n,&q);
int x;
for(int i=2;i<=n;i++)
{
scanf("%d",&x); x++;
add(x,i);
}
int l,r,z;
for(int i=1;i<=q;i++)
{
scanf("%d%d%d",&l,&r,&z);
l++; r++; z++;
qq[i].z=z;
a[++cnt].p=l-1; a[cnt].no=i; a[cnt].flag=0;
a[++cnt].p=r; a[cnt].no=i; a[cnt].flag=1;
}
sort(a+1,a+cnt+1,cmp);
dfs1(1,0);
dfs2(1,1);
build(1,1,n);
int now=0;
for(int i=1;i<=cnt;i++)
{
while(now<a[i].p)
{
now++;
modify(now,1);
}
int t=a[i].no;
if(!a[i].flag) qq[t].ans1=query(qq[t].z,1);
else qq[t].ans2=query(qq[t].z,1);
}
for(int i=1;i<=q;i++)
printf("%lld\n",(qq[i].ans2-qq[i].ans1+mod)%mod);
return 0;
}