引言
如果给你一串数,查询区间的和(max或min),该如何做?
线段树(树状数组)
但如果这个问题换到树上,查询 u,v 路径上的和呢?
树LCA+暴力
如果数据随机,这种方法是可以过的!
可是你觉得出题人会这么好心吗?
所以我们引出树链剖分,将这类问题转化成由线段树可以解决的问题!
定义
deep[i]
d
e
e
p
[
i
]
为i节点的深度(根的深度为1)
son[i]
s
o
n
[
i
]
为i的重儿子
fa[i]
f
a
[
i
]
为i的父节点
size[i]
s
i
z
e
[
i
]
为以i为根节点的子树大小
top[i]
t
o
p
[
i
]
为i所在的重链的顶节点
重儿子:子节点中size值最大的为重儿子,其他则为轻儿子。
重边:连接该节点与它的重儿子的边。
重链:由一系列重边相连得到的链。
轻链:由一系列非重边相连得到的链。
上图解释一下。
1的重儿子为4,4的重儿子为5,5的重儿子为6.
2的重儿子为3
那么1->4->5->6为一条重链
2->3为一条重链
7为一条重链
图中的重链涵盖了所以的树上节点,重链之间由轻边连接。
原理
若我们对上图由重儿子优先的原理进行dfs编号,会得到下图
橙色数字为新编号
重链上的序号都是连续的。
我们可以通过这张重建的新图建立线段树。
实现
先通过dfs1得到每个节点的size,fa,son,deep
int dfs1(int now,int fax,int dep)
{
fa[now]=fax,deep[now]=dep,size[now]=1;
int maxson=-1;
for(int i=head[now];i;i=net[i])
if(to[i]!=fax)
{
size[now]+=dfs1(to[i],now,dep+1);
if(maxson<size[to[i]])
son[now]=to[i],maxson=size[to[i]];
}
return size[now];
}
通过dfs2得到top以及进行重新编号
void dfs2(int now,int topx)
{
id[now]=++tot;
val[tot]=p[now];
top[now]=topx;
if(!son[now]) return;
dfs2(son[now],topx);
for(int i=head[now];i;i=net[i])
if(!id[to[i]])
dfs2(to[i],to[i]);
}
线段树的基本操作
- 建树
void build(int o,int l,int r)
{
t[o].l=l,t[o].r=r,t[o].siz=r-l+1;
if(l==r)
{
t[o].sum=val[l];
return;
}
int mid=(l+r)>>1;
build(lson),build(rson);
update(o);
}
- 区间求和
int ask(int o,int ql,int qr)
{
int l=t[o].l,r=t[o].r;
if(ql<=l&&qr>=r)
return t[o].sum%mod;
pushdown(o);
int mid=(l+r)>>1;
int p1=0,p2=0;
if(ql<=mid) p1=ask((o<<1),ql,qr);
if(qr>mid) p2=ask((o<<1)|1,ql,qr);
return (p1+p2)%mod;
}
- 区间修改
void adj(int o,int ql,int qr,int num)
{
int l=t[o].l,r=t[o].r;
if(ql<=l&&qr>=r)
{
t[o].sum=(t[o].sum+(t[o].siz*num)%mod)%mod;
t[o].add=(t[o].add+num)%mod;
return;
}
pushdown(o);
int mid=(l+r)>>1;
if(ql<=mid) adj((o<<1),ql,qr,num);
if(qr>mid) adj((o<<1)|1,ql,qr,num);
update(o);
}
查询方式
树剖和线段树都搞好以后,如何解决u,v之间路径和的问题?
假设deep[u]>deep[v]
1:若二者在一条重链上
由于重链中序号是连续的
所以
Ans=ask(1,id[v],id[u])
A
n
s
=
a
s
k
(
1
,
i
d
[
v
]
,
i
d
[
u
]
)
2:二者在两条重链上
假设u所在的重链顶点deep比较大
我们可以先查询
ask(1,id[top[u]],id[u])
a
s
k
(
1
,
i
d
[
t
o
p
[
u
]
]
,
i
d
[
u
]
)
然后令u=fa[top[u]]
如此继续下去,就可以把两者转化至一条重链上,然后沿用情况1的方法
在这里可以看出,树剖可以求LCA!
至于修改操作,跟上面的大同小异。
u,v查询
inline int tree_sum(int u,int v)
{
int ans=0;
while(top[u]!=top[v])
{
if(deep[top[u]]<deep[top[v]]) swap(u,v);
ans=(ans+ask(1,id[top[u]],id[u])%mod)%mod;
u=fa[top[u]];
}
if(deep[u]<deep[v]) swap(u,v);
ans=(ans+ask(1,id[v],id[u])%mod)%mod;
return ans%mod;
}
u,v修改
inline void tree_x(int u,int v,int num)
{
while(top[u]!=top[v])
{
if(deep[top[u]]<deep[top[v]]) swap(u,v);
adj(1,id[top[u]],id[u],num);
u=fa[top[u]];
}
if(deep[u]<deep[v]) swap(u,v);
adj(1,id[v],id[u],num);
}
其他问题
1:查询以u为根节点的子树和
Ans=ask(1,id[u],id[u]+size[u]−1)
A
n
s
=
a
s
k
(
1
,
i
d
[
u
]
,
i
d
[
u
]
+
s
i
z
e
[
u
]
−
1
)
(编号连续)
2:修改以u为根节点的子树
change(1,id[u],id[u]+size[u]−1)
c
h
a
n
g
e
(
1
,
i
d
[
u
]
,
i
d
[
u
]
+
s
i
z
e
[
u
]
−
1
)
时间复杂度分析
性质1
如果边(u,v)为轻边,那么Size(v)≤Size(u)/2。
证明:反之,则为重边
性质2
树中任意两个节点之间的路径中轻边的条数不会超过
log2n
l
o
g
2
n
,重路径的数目不会超过
log2n
l
o
g
2
n
证明:无!
分析:
由于重路径的数量的上界为
log2n
l
o
g
2
n
线段树中查询/修改的复杂度为
log2n
l
o
g
2
n
那么总的时间复杂度为
(log2n)2
(
l
o
g
2
n
)
2
代码
#include <cstdio>
#include <iostream>
#define lson (o<<1),l,mid
#define rson (o<<1)+1,mid+1,r
using namespace std;
const int maxm=1e5+1;
int deep[maxm],son[maxm],size[maxm],fa[maxm],top[maxm];
int head[maxm],to[maxm<<1],net[maxm<<1],cnt;
int p[maxm],val[maxm];
int id[maxm],tot;
int n,m,root,mod;
struct node{
int l,r,siz,add,sum;
};
node t[maxm*4];
inline void add(int x,int y)
{
to[++cnt]=y;
net[cnt]=head[x];
head[x]=cnt;
}
int dfs1(int now,int fax,int dep)
{
fa[now]=fax,deep[now]=dep,size[now]=1;
int maxson=-1;
for(int i=head[now];i;i=net[i])
if(to[i]!=fax)
{
size[now]+=dfs1(to[i],now,dep+1);
if(maxson<size[to[i]])
son[now]=to[i],maxson=size[to[i]];
}
return size[now];
}
void dfs2(int now,int topx)
{
id[now]=++tot;
val[tot]=p[now];
top[now]=topx;
if(!son[now]) return;
dfs2(son[now],topx);
for(int i=head[now];i;i=net[i])
if(!id[to[i]])
dfs2(to[i],to[i]);
}
inline void update(int o)
{
t[o].sum=(t[(o<<1)].sum%mod+t[(o<<1)|1].sum%mod)%mod;
}
inline void pushdown(int o)
{
int adi=t[o].add%mod;
for(int i=0;i<=1;i++)
t[(o<<1)+i].sum=(t[(o<<1)+i].sum+(t[(o<<1)+i].siz*adi)%mod)%mod,t[(o<<1)+i].add=(t[(o<<1)+i].add+adi)%mod;
t[o].add=0;
}
void adj(int o,int ql,int qr,int num)
{
int l=t[o].l,r=t[o].r;
if(ql<=l&&qr>=r)
{
t[o].sum=(t[o].sum+(t[o].siz*num)%mod)%mod;
t[o].add=(t[o].add+num)%mod;
return;
}
pushdown(o);
int mid=(l+r)>>1;
if(ql<=mid) adj((o<<1),ql,qr,num);
if(qr>mid) adj((o<<1)|1,ql,qr,num);
update(o);
}
int ask(int o,int ql,int qr)
{
int l=t[o].l,r=t[o].r;
if(ql<=l&&qr>=r)
return t[o].sum%mod;
pushdown(o);
int mid=(l+r)>>1;
int p1=0,p2=0;
if(ql<=mid) p1=ask((o<<1),ql,qr);
if(qr>mid) p2=ask((o<<1)|1,ql,qr);
return (p1+p2)%mod;
}
void build(int o,int l,int r)
{
t[o].l=l,t[o].r=r,t[o].siz=r-l+1;
if(l==r)
{
t[o].sum=val[l];
return;
}
int mid=(l+r)>>1;
build(lson),build(rson);
update(o);
}
inline int tree_sum(int u,int v)
{
int ans=0;
while(top[u]!=top[v])
{
if(deep[top[u]]<deep[top[v]]) swap(u,v);
ans=(ans+ask(1,id[top[u]],id[u])%mod)%mod;
u=fa[top[u]];
}
if(deep[u]<deep[v]) swap(u,v);
ans=(ans+ask(1,id[v],id[u])%mod)%mod;
return ans%mod;
}
inline void tree_x(int u,int v,int num)
{
while(top[u]!=top[v])
{
if(deep[top[u]]<deep[top[v]]) swap(u,v);
adj(1,id[top[u]],id[u],num);
u=fa[top[u]];
}
if(deep[u]<deep[v]) swap(u,v);
adj(1,id[v],id[u],num);
}
inline int read()
{
int x=0;char ch=0;
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x;
}
int main()
{
n=read(),m=read(),root=read(),mod=read();
for(int i=1;i<=n;i++)
p[i]=read(),p[i]%=mod;
for(int i=1;i<n;i++)
{
int u,v;
u=read(),v=read();
add(u,v),add(v,u);
}
dfs1(root,0,1);
dfs2(root,root);
build(1,1,tot);
for(int i=1;i<=m;i++)
{
int opt,u,v,x;
opt=read();
if(opt==1)
{
u=read(),v=read(),x=read();
tree_x(u,v,x%mod);
}
if(opt==2)
{
u=read(),v=read();
printf("%d\n",tree_sum(u,v));
}
if(opt==3)
{
u=read(),x=read();
adj(1,id[u],id[u]+size[u]-1,x%mod);
}
if(opt==4)
{
u=read();
printf("%d\n",ask(1,id[u],id[u]+size[u]-1));
}
}
return 0;
}