题目链接:https://www.luogu.com.cn/problem/P3384
建树时是按照dfs序来的。
需要对子树进行操作,因为一颗子树的dfs序一定是连续的,因此也可以用线段树维护。
链的更新和子树的更新是不同的。 链的更新需要查找路径上经过的节点,让uv不断的往上搜索。子树的更新只需要计算当前根节点的dfs序和子树内最大的dfs序即tid[b]+sz[b]-1,然后直接用线段树更改tid[b]~ tid[b]+sz[b]-1这一段区间即可。
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <vector>
#include <queue>
#include <stack>
using namespace std;
const long long maxn=1e5+5;
long long n,m,r,p;
vector <long long> g[maxn];
long long sz[maxn],dep[maxn];
long long ch[maxn],fa[maxn];
long long top[maxn],tid[maxn],tid2[maxn];
long long tot,a[maxn];
void dfs1(long long u,long long f,long long d)
{
sz[u]=1;
fa[u]=f;
dep[u]=d;
for(long long i=0; i<g[u].size(); i++)
{
long long v=g[u][i];
if(v==f)
continue;
dfs1(v,u,d+1);
sz[u]+=sz[v];
if(ch[u]==-1||sz[v]>sz[ch[u]])
ch[u]=v;
}
}
void dfs2(long long u,long long tp)
{
top[u]=tp;
tid[u]=++tot;
tid2[tot]=u;
if(ch[u]==-1)
return ;
dfs2(ch[u],tp);
for(long long i=0; i<g[u].size(); i++)
{
long long v=g[u][i];
if(v!=fa[u]&&v!=ch[u])
dfs2(v,v);
}
}
struct segmenttree
{
long long sum[maxn<<2],lazy[maxn<<2];
void build(int i,int l,int r)
{
if(l==r)
{
sum[i]=a[tid2[l]];
return ;
}
int mid=(l+r)/2;
build(i*2,l,mid);
build(i*2+1,mid+1,r);
pushup(i);
}
void pushdown(long long i,long long l,long long r)
{
if(!lazy[i])
return ;
long long mid=(l+r)/2;
sum[i*2]+=(mid-l+1)*lazy[i];
sum[i*2+1]+=(r-mid)*lazy[i];
lazy[i*2]+=lazy[i];
lazy[i*2+1]+=lazy[i];
lazy[i]=0;
sum[i*2]%=p;
sum[i*2+1]%=p;
lazy[i*2]%=p;
lazy[i*2+1]%=p;
}
void pushup(long long i)
{
sum[i]=sum[i*2]+sum[i*2+1];
sum[i]%=p;
}
void update(long long i,long long l,long long r,long long L,long long R,long long val)
{
if(l>=L&&r<=R)
{
sum[i]+=(r-l+1)*val;
sum[i]%=p;
lazy[i]+=val;
lazy[i]%=p;
return ;
}
long long mid=(l+r)/2;
pushdown(i,l,r);
if(L<=mid)
update(i*2,l,mid,L,R,val);
if(R>mid)
update(i*2+1,mid+1,r,L,R,val);
pushup(i);
}
void update(long long u,long long v,long long val)
{
long long f1=top[u],f2=top[v];
while(f1!=f2)
{
if(dep[f1]<dep[f2])
{
swap(f1,f2);
swap(u,v);
}
update(1,1,n,tid[f1],tid[u],val);
u=fa[f1];
f1=top[u];
}
if(dep[u]<dep[v])
swap(u,v);
update(1,1,n,tid[v],tid[u],val);
}
long long query(long long i,long long l,long long r,long long L,long long R)
{
if(l>=L&&r<=R)
{
return sum[i];
}
long long mid=(l+r)/2;
pushdown(i,l,r);
long long ans=0;
if(L<=mid)
ans+=query(i*2,l,mid,L,R);
if(R>mid)
ans+=query(i*2+1,mid+1,r,L,R);
return ans%p;
}
long long query(long long u,long long v)
{
long long f1=top[u],f2=top[v];
long long ans=0;
while(f1!=f2)
{
if(dep[f1]<dep[f2])
{
swap(u,v);
swap(f1,f2);
}
ans+=query(1,1,n,tid[f1],tid[u]);
ans%=p;
u=fa[f1];
f1=top[u];
}
if(dep[u]<dep[v])
swap(u,v);
ans+=query(1,1,n,tid[v],tid[u]);
return ans%p;
}
} st;
int main()
{
scanf("%lld%lld%lld%lld",&n,&m,&r,&p);
for(long long i=1; i<=n; i++)
{
scanf("%lld",&a[i]);
}
for(long long i=0; i<n-1; i++)
{
long long u,v;
scanf("%lld%lld",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
memset(ch,-1,sizeof(ch));
dfs1(r,0,0);
dfs2(r,r);
st.build(1,1,n);
long long a,b,c,d;
long long ans;
while(m--)
{
scanf("%lld",&a);
if(a==1)
{
scanf("%lld%lld%lld",&b,&c,&d);
st.update(b,c,d%p);
}
else if(a==2)
{
scanf("%lld%lld",&b,&c);
ans=st.query(b,c);
printf("%lld\n",ans%p);
}
else if(a==3)
{
scanf("%lld%lld",&b,&c);
d=tid2[tid[b]+sz[b]-1];
st.update(1,1,n,tid[b],tid[d],c%p);
}
else
{
scanf("%lld",&b);
d=tid2[tid[b]+sz[b]-1];
ans=st.query(1,1,n,tid[b],tid[d]);
printf("%lld\n",ans%p);
}
}
return 0;
}