参照洛谷模板 P3384 【模板】树链剖分 题意:
给你一棵包含n个结点的树,现要求你支持以下操作:
1.x到y结点最短路径上所有节点的值都加上z;
2.求树从x到y结点最短路径上所有节点的值之和;
3.将以x为根节点的子树内所有节点值都加上z;
4.求以x为根节点的子树内所有节点值之和。
正确的思路是树链剖分套线段树。
我们可以求一个dfs序(dfn[]),这个dfs序与之前求top[x]是同步的。所以这样就保证:
每一条链上的点的dfs序是连续的,每一个点的所有子节点的编号分布在[dfn[x],dfn[x]+siz[x]-1]之间。
这样我们就可以以每个点的dfn值作为其新下标,开一棵线段树维护即可。
对于操作1:在树剖找LCA(x,y)的过程中不断对零散区间修改即可。
void LCAu(int x,int y,long long k)
{
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
update(1,dfn[top[x]],dfn[x],k);//显然在一条链上的点的dfn值是连续的
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
update(1,dfn[y],dfn[x],k);
}
对于操作2:在树剖找LCA(x,y)的过程中不断累加零散区间的值即可。
long long LCAq(int x,int y)
{
long long ans=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
ans=(ans+query(1,dfn[top[x]],dfn[x])+MOD)%MOD;
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
ans=(ans+query(1,dfn[y],dfn[x])+MOD)%MOD;
return ans;
}
对于操作3:修改dfn值属于[dfn[x],dfn[x]+siz[x]-1]的结点的权值即可。
对于操作4:查询dfn值属于[dfn[x],dfn[x]+siz[x]-1]的结点的权值的和即可。
完整代码:
#include<cstdio>
#include<iostream>
#define ri register int
using namespace std;
const int MAXN=200020;
int n,m,s,q,u[MAXN],v[MAXN],fst[MAXN],nxt[MAXN],key,xl,xto;
long long MOD,w[MAXN],a[MAXN],num;
int fa[MAXN],deep[MAXN],siz[MAXN],cmax[MAXN],son[MAXN],top[MAXN],cur,dfn[MAXN];
int l[MAXN<<2],r[MAXN<<2];
long long sum[MAXN<<2],len[MAXN<<2],tag[MAXN<<2];
void dfs1(int x,int father,int dep)
{
fa[x]=father,deep[x]=dep,siz[x]=1;
for(ri k=fst[x];k>0;k=nxt[k])
if(v[k]!=father)
{
dfs1(v[k],x,dep+1);
if(siz[v[k]]>cmax[x]) cmax[x]=siz[v[k]],son[x]=v[k];
siz[x]+=siz[v[k]];
}
}
void dfs2(int x,int anc)
{
top[x]=anc,dfn[x]=++cur,a[cur]=w[x];
if(son[x]) dfs2(son[x],anc);
for(ri k=fst[x];k>0;k=nxt[k])
if(v[k]!=fa[x]&&v[k]!=son[x]) dfs2(v[k],v[k]);
}
void pushup(int p)
{
sum[p]=(sum[p <<1]+sum[p <<1|1]+MOD)%MOD;
}
void pushdown(int p)
{
sum[p <<1]=(sum[p <<1]+(len[p <<1]*tag[p])%MOD+MOD)%MOD;
tag[p <<1]=(tag[p <<1]+tag[p]+MOD)%MOD;
sum[p <<1|1]=(sum[p <<1|1]+(len[p <<1|1]*tag[p])%MOD+MOD)%MOD;
tag[p <<1|1]=(tag[p <<1|1]+tag[p]+MOD)%MOD;
tag[p]=0;
}
void build(int p,int lft,int rit)
{
l[p]=lft,r[p]=rit;
if(lft==rit)
{
sum[p]=a[lft],len[p]=1;
return;
}
int mid=(lft+rit)>>1;
build(p <<1,lft,mid);
build(p <<1|1,mid+1,rit);
pushup(p);
len[p]=len[p <<1]+len[p <<1|1];
}
void update(int p,int lft,int rit,long long k)
{
if(lft<=l[p]&&r[p]<=rit)
{
sum[p]=(sum[p]+(len[p]*k)%MOD+MOD)%MOD,tag[p]=(tag[p]+k+MOD)%MOD;
return;
}
pushdown(p);
if(lft<=r[p <<1]) update(p <<1,lft,rit,k);
if(l[p <<1|1]<=rit) update(p <<1|1,lft,rit,k);
pushup(p);
}
long long query(int p,int lft,int rit)
{
if(lft<=l[p]&&r[p]<=rit) return sum[p];
long long ans=0;
pushdown(p);
if(lft<=r[p <<1]) ans=query(p <<1,lft,rit);
if(l[p <<1|1]<=rit) ans=(ans+query(p <<1|1,lft,rit)+MOD)%MOD;
return ans;
}
void LCAu(int x,int y,long long k)
{
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
update(1,dfn[top[x]],dfn[x],k);
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
update(1,dfn[y],dfn[x],k);
}
long long LCAq(int x,int y)
{
long long ans=0;
while(top[x]!=top[y])
{
if(deep[top[x]]<deep[top[y]]) swap(x,y);
ans=(ans+query(1,dfn[top[x]],dfn[x])+MOD)%MOD;
x=fa[top[x]];
}
if(deep[x]<deep[y]) swap(x,y);
ans=(ans+query(1,dfn[y],dfn[x])+MOD)%MOD;
return ans;
}
int main()
{
scanf("%d%d%d%lld",&n,&q,&s,&MOD);
for(ri i=1;i<=n;i++) scanf("%lld",&w[i]);
m=(n-1)<<1;
for(ri i=1;i<=m;i+=2)
{
scanf("%d%d",&u[i],&v[i]);
nxt[i]=fst[u[i]],fst[u[i]]=i;
u[i+1]=v[i],v[i+1]=u[i];
nxt[i+1]=fst[u[i+1]],fst[u[i+1]]=i+1;
}
dfs1(s,s,0);
dfs2(s,s);
build(1,1,n);
for(ri i=1;i<=q;i++)
{
scanf("%d%d",&key,&xl);
if(key==1)
{
scanf("%d%lld",&xto,&num);
LCAu(xl,xto,num);
}
if(key==2)
{
scanf("%d",&xto);
cout<<LCAq(xl,xto)<<'\n';
}
if(key==3)
{
scanf("%lld",&num);
update(1,dfn[xl],dfn[xl]+siz[xl]-1,num);
}
if(key==4) cout<<query(1,dfn[xl],dfn[xl]+siz[xl]-1)<<'\n';
}
return 0;
}