题目链接:https://www.luogu.org/problemnew/show/P3384
题目大意:
如题,已知一棵包含N个结点的树(连通且无环),每个节点上包含一个数值,需要支持以下操作:
操作1: 格式: 1 x y z 表示将树从x到y结点最短路径上所有节点的值都加上z
操作2: 格式: 2 x y 表示求树从x到y结点最短路径上所有节点的值之和
操作3: 格式: 3 x z 表示将以x为根节点的子树内所有节点值都加上z
操作4: 格式: 4 x 表示求以x为根节点的子树内所有节点值之和
题目思路:树链剖分的入门题。树链剖分它实际上就是通过两次dfs得到重儿子,然后得到轻边和重边,然后不断地通过攀爬top使得他们能到一条链上。由于通过轻边往上爬最多爬logn次,所以复杂度logn*logn,通过重儿子得到dfs序可以保证同一条链的dfs序在一起,从而方便线段树操作。而对于子树,由于他只是优先跑了重儿子,但是还保留着dfs序的性质,即他的子树还是在他的后边,所以可以直接通过siz和本身的dfs序来进行子树修改
以下是代码:
#include <bits/stdc++.h>
using namespace std;
#define rep(i,a,b) for(ll i=a;i<=b;i++)
#define per(i,a,b) for(ll i=a;i>=b;i--)
#define ll long long
#define inf 0x3f3f3f3f
const ll MAXN = 1e5+5;
ll a[MAXN],fa[MAXN],dfn[MAXN],son[MAXN],dep[MAXN],siz[MAXN],rk[MAXN],top[MAXN],tot;
vector<ll>v[MAXN];
ll n,m,r,p,op,x,y,z;
struct node{
ll l,r,val,mark;
}t[MAXN<<2];
void dfs1(ll u,ll f){
fa[u]=f,siz[u]=1,son[u]=0;
dep[u]=dep[f]+1;
ll len=v[u].size();
rep(i,0,len-1){
ll to=v[u][i];
if(to==f)continue;
dfs1(to,u);
siz[u]+=siz[to];
if(siz[to]>siz[son[u]]){
son[u]=to;
}
}
}
void dfs2(ll u,ll tp){
top[u]=tp,dfn[u]=++tot,rk[tot]=u;
if(son[u])dfs2(son[u],tp);
ll len=v[u].size();
rep(i,0,len-1){
if(v[u][i]==fa[u]||v[u][i]==son[u])continue;
ll to=v[u][i];
dfs2(to,to);
}
}
void build(ll rt,ll l,ll r){
t[rt].l=l,t[rt].r=r,t[rt].mark=0;
if(l==r){
t[rt].val=a[rk[l]];
return;
}
ll mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
t[rt].val=(t[rt<<1].val+t[rt<<1|1].val)%p;
}
void spread(ll rt){
if(t[rt].mark){
t[rt<<1].val=(t[rt<<1].val+(t[rt<<1].r-t[rt<<1].l+1)*t[rt].mark)%p;
t[rt<<1|1].val=(t[rt<<1|1].val+(t[rt<<1|1].r-t[rt<<1|1].l+1)*t[rt].mark)%p;
t[rt<<1].mark=(t[rt<<1].mark+t[rt].mark)%p;
t[rt<<1|1].mark=(t[rt<<1|1].mark+t[rt].mark)%p;
t[rt].mark=0;
}
}
void update(ll rt,ll l,ll r,ll val){
if(t[rt].l>=l&&t[rt].r<=r){
t[rt].val=(t[rt].val+(t[rt].r-t[rt].l+1)*val)%p;
t[rt].mark=(t[rt].mark+val)%p;
return ;
}
spread(rt);
ll mid=(t[rt].l+t[rt].r)>>1;
if(l<=mid)update(rt<<1,l,r,val);
if(r>mid)update(rt<<1|1,l,r,val);
t[rt].val=(t[rt<<1].val+t[rt<<1|1].val)%p;
}
ll query(ll rt,ll l,ll r){
if(t[rt].l>=l&&t[rt].r<=r){
return t[rt].val;
}
spread(rt);
ll mid=(t[rt].l+t[rt].r)>>1;
ll ans=0;
if(l<=mid)ans=(ans+query(rt<<1,l,r))%p;
if(r>mid)ans=(ans+query(rt<<1|1,l,r))%p;
return ans;
}
void update1(ll x,ll y,ll z){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
update(1,dfn[top[x]],dfn[x],z);
x=fa[top[x]];
}
if(dfn[x]>dfn[y])swap(x,y);
update(1,dfn[x],dfn[y],z);
}
ll query1(ll x,ll y){
ll ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
ans=(ans+query(1,dfn[top[x]],dfn[x]))%p;
x=fa[top[x]];
}
if(dfn[x]>dfn[y])swap(x,y);
ans=(ans+query(1,dfn[x],dfn[y]))%p;
return ans;
}
int main(){
while(~scanf("%lld%lld%lld%lld",&n,&m,&r,&p)){
memset(t,0,sizeof(t));
rep(i,1,n)scanf("%lld",&a[i]),v[i].clear(),siz[i]=0;
rep(i,1,n-1){
scanf("%lld%lld",&x,&y);
v[x].push_back(y);
v[y].push_back(x);
}
tot=0;
dep[0]=0;
dfs1(r,0);
dfs2(r,r);
build(1,1,n);
rep(i,1,m){
scanf("%lld",&op);
if(op==1){
scanf("%lld%lld%lld",&x,&y,&z);
update1(x,y,z);
}
else if(op==2){
scanf("%lld%lld",&x,&y);
ll ans=query1(x,y);
printf("%lld\n",ans);
}
else if(op==3){
scanf("%lld%lld",&x,&z);
update(1,dfn[x],dfn[x]+siz[x]-1,z);
}
else if(op==4){
scanf("%lld",&x);
ll ans=query(1,dfn[x],dfn[x]+siz[x]-1);
printf("%lld\n",ans);
}
}
}
return 0;
}